Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into extend_astype_with_…
Browse files Browse the repository at this point in the history
…tuples

Conflicts:
	src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py
	src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py
	tests/next_tests/exclusion_matrices.py
	tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py
  • Loading branch information
Nina Burgdorfer authored and Nina Burgdorfer committed Oct 19, 2023
2 parents 869ae60 + f96ead5 commit 3823227
Show file tree
Hide file tree
Showing 5 changed files with 231 additions and 84 deletions.
163 changes: 142 additions & 21 deletions src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,44 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from typing import Any, Mapping, Sequence
import hashlib
from typing import Any, Mapping, Optional, Sequence

import dace
import numpy as np
from dace.codegen.compiled_sdfg import CompiledSDFG
from dace.transformation.auto import auto_optimize as autoopt

import gt4py.next.iterator.ir as itir
from gt4py.next import common
from gt4py.next.iterator.embedded import NeighborTableOffsetProvider
from gt4py.next.common import Dimension, Domain, UnitRange, is_field
from gt4py.next.iterator.embedded import NeighborTableOffsetProvider, StridedNeighborOffsetProvider
from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms
from gt4py.next.otf.compilation import cache
from gt4py.next.program_processors.processor_interface import program_executor
from gt4py.next.type_system import type_translation
from gt4py.next.type_system import type_specifications as ts, type_translation

from .itir_to_sdfg import ItirToSDFG
from .utility import connectivity_identifier, filter_neighbor_tables
from .utility import connectivity_identifier, filter_neighbor_tables, get_sorted_dims


def get_sorted_dim_ranges(domain: Domain) -> Sequence[UnitRange]:
    """Return the unit ranges of *domain*, ordered by the canonical sorted dimension order.

    The indices produced by ``get_sorted_dims`` refer back into the original
    (unsorted) dimension order, so they are used to pick the matching range.
    """
    return [domain.ranges[original_index] for original_index, _dim in get_sorted_dims(domain.dims)]


""" Default build configuration in DaCe backend """
_build_type = "Release"
# removing -ffast-math from DaCe default compiler args in order to support isfinite/isinf/isnan built-ins
_cpu_args = (
"-std=c++14 -fPIC -Wall -Wextra -O3 -march=native -Wno-unused-parameter -Wno-unused-label"
)


def convert_arg(arg: Any):
    """Convert a field argument into an ndarray with canonically sorted axes.

    Non-field arguments (scalars etc.) are returned unchanged.  For fields, the
    axes of the underlying ndarray are reordered so that they match the sorted
    dimension order used for the corresponding SDFG arrays.
    """
    if is_field(arg):
        sorted_dims = get_sorted_dims(arg.domain.dims)
        ndim = len(sorted_dims)
        # get_sorted_dims yields (original_index, dim) pairs; the original
        # indices are the destination positions for np.moveaxis.
        dim_indices = [dim_index for dim_index, _ in sorted_dims]
        assert isinstance(arg.ndarray, np.ndarray)
        return np.moveaxis(arg.ndarray, range(ndim), dim_indices)
    return arg
Expand Down Expand Up @@ -69,6 +84,17 @@ def get_shape_args(
}


def get_offset_args(
    arrays: Mapping[str, dace.data.Array], params: Sequence[itir.Sym], args: Sequence[Any]
) -> Mapping[str, int]:
    """Map each array-offset symbol to the negated start of the matching dim range.

    Only field arguments contribute entries; offset symbols of each field's SDFG
    array are paired positionally with its sorted dimension ranges.
    """
    offset_args: dict[str, int] = {}
    for param, arg in zip(params, args):
        if not is_field(arg):
            continue
        array_offsets = arrays[param.id].offset
        for sym, dim_range in zip(array_offsets, get_sorted_dim_ranges(arg.domain)):
            offset_args[str(sym)] = -dim_range.start
    return offset_args


def get_stride_args(
arrays: Mapping[str, dace.data.Array], args: Mapping[str, Any]
) -> Mapping[str, int]:
Expand All @@ -85,44 +111,139 @@ def get_stride_args(
return stride_args


# Per-device caches of compiled SDFG programs, keyed by the hash produced by
# get_cache_id(); passed as ``build_cache`` into run_dace_iterator().
_build_cache_cpu: dict[str, CompiledSDFG] = {}
_build_cache_gpu: dict[str, CompiledSDFG] = {}


def get_cache_id(
    program: itir.FencilDefinition,
    arg_types: Sequence[ts.TypeSpec],
    column_axis: Optional[Dimension],
    offset_provider: Mapping[str, Any],
) -> str:
    """Compute a deterministic id for a compiled program variant.

    The id folds together the textual form of the ITIR program, the argument
    types, the column axis, and the (name, max_neighbors) of every neighbor
    offset provider, so any change in these produces a different cache key.
    """
    hasher = hashlib.sha256()
    for part in (program, *arg_types, column_axis):
        hasher.update(str(part).encode())
    for name, provider in offset_provider.items():
        if isinstance(provider, (NeighborTableOffsetProvider, StridedNeighborOffsetProvider)):
            hasher.update(str((name, provider.max_neighbors)).encode())
    return hasher.hexdigest()


@program_executor
def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
    """Lower *program* to a DaCe SDFG, compile it (with caching), and run it on *args*.

    Recognized keyword arguments:
        auto_optimize: apply DaCe auto-optimization heuristics (default: False).
        build_type: CMake build type used when compiling the SDFG.
        run_on_gpu: build for GPU and place non-transient arrays in GPU global memory.
        build_cache: optional mutable mapping caching compiled SDFG programs.
        column_axis: ITIR column dimension, forwarded to the SDFG generator.
        offset_provider: mapping of offset names to providers (required).
    """
    # build parameters
    auto_optimize = kwargs.get("auto_optimize", False)
    build_type = kwargs.get("build_type", "RelWithDebInfo")
    run_on_gpu = kwargs.get("run_on_gpu", False)
    build_cache = kwargs.get("build_cache", None)
    # ITIR parameters
    column_axis = kwargs.get("column_axis", None)
    offset_provider = kwargs["offset_provider"]

    arg_types = [type_translation.from_value(arg) for arg in args]
    neighbor_tables = filter_neighbor_tables(offset_provider)

    cache_id = get_cache_id(program, arg_types, column_axis, offset_provider)
    if build_cache is not None and cache_id in build_cache:
        # retrieve SDFG program from build cache
        sdfg_program = build_cache[cache_id]
        sdfg = sdfg_program.sdfg
    else:
        # visit ITIR and generate SDFG
        program = preprocess_program(program, offset_provider)
        sdfg_generator = ItirToSDFG(arg_types, offset_provider, column_axis)
        sdfg = sdfg_generator.visit(program)
        sdfg.simplify()

        # set array storage for GPU execution
        if run_on_gpu:
            device = dace.DeviceType.GPU
            sdfg._name = f"{sdfg.name}_gpu"
            for _, _, array in sdfg.arrays_recursive():
                if not array.transient:
                    array.storage = dace.dtypes.StorageType.GPU_Global
        else:
            device = dace.DeviceType.CPU

        # run DaCe auto-optimization heuristics
        if auto_optimize:
            # TODO Investigate how symbol definitions improve autoopt transformations,
            # in which case the cache table should take the symbols map into account.
            symbols: dict[str, int] = {}
            sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols)

        # compile SDFG and retrieve SDFG program
        sdfg.build_folder = cache._session_cache_dir_path / ".dacecache"
        with dace.config.temporary_config():
            dace.config.Config.set("compiler", "build_type", value=build_type)
            dace.config.Config.set("compiler", "cpu", "args", value=_cpu_args)
            sdfg_program = sdfg.compile(validate=False)

        # store SDFG program in build cache
        if build_cache is not None:
            build_cache[cache_id] = sdfg_program

    # collect call arguments: data, connectivities, shapes, strides, and offsets
    dace_args = get_args(program.params, args)
    dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)}
    dace_conn_args = get_connectivity_args(neighbor_tables)
    dace_shapes = get_shape_args(sdfg.arrays, dace_field_args)
    dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args)
    dace_strides = get_stride_args(sdfg.arrays, dace_field_args)
    dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args)
    dace_offsets = get_offset_args(sdfg.arrays, program.params, args)

    all_args = {
        **dace_args,
        **dace_conn_args,
        **dace_shapes,
        **dace_conn_shapes,
        **dace_strides,
        **dace_conn_strides,
        **dace_offsets,
    }
    # keep only the arguments that actually appear in the SDFG call signature
    expected_args = {
        key: value
        for key, value in all_args.items()
        if key in sdfg.signature_arglist(with_types=False)
    }

    with dace.config.temporary_config():
        dace.config.Config.set("compiler", "allow_view_arguments", value=True)
        dace.config.Config.set("frontend", "check_args", value=True)
        sdfg_program(**expected_args)


@program_executor
def run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
    """Execute *program* through the DaCe iterator backend targeting the CPU."""
    run_dace_iterator(
        program,
        *args,
        **kwargs,
        run_on_gpu=False,
        build_type=_build_type,
        build_cache=_build_cache_cpu,
    )


@program_executor
def run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
    """Execute *program* through the DaCe iterator backend targeting the GPU."""
    run_dace_iterator(
        program,
        *args,
        **kwargs,
        run_on_gpu=True,
        build_type=_build_type,
        build_cache=_build_cache_gpu,
    )
Loading

0 comments on commit 3823227

Please sign in to comment.