fix[next][dace]: Use constant shape for neighbor tables in local dimension (#1422)

The main purpose of this PR is to avoid defining shape symbols for array dimensions that are known at compile time. The local size of neighbor connectivity tables falls into this category: for each element in the origin dimension, the number of elements in the target dimension is given by the max_neighbors attribute of the offset provider.
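As a rough illustration (not part of this commit; dimension names and values invented), a neighbor table whose local dimension has a compile-time size given by max_neighbors:

import numpy as np

from gt4py.next import Dimension
from gt4py.next.iterator.embedded import NeighborTableOffsetProvider

Edge = Dimension("Edge")
Vertex = Dimension("Vertex")

# Each edge has exactly two vertex endpoints: for every element of the origin
# dimension (Edge), the target dimension has the fixed size max_neighbors == 2.
e2v_table = np.array([[0, 1], [1, 2], [2, 0]], dtype=np.int64)
e2v = NeighborTableOffsetProvider(e2v_table, Edge, Vertex, max_neighbors=2)

# The SDFG can thus use the literal 2 for this array dimension instead of a shape symbol.
assert e2v.max_neighbors == e2v_table.shape[1] == 2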
edopao authored Jan 25, 2024
1 parent 11f9c1c commit ac0478a
Showing 4 changed files with 160 additions and 83 deletions.
96 changes: 58 additions & 38 deletions src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
@@ -14,6 +14,7 @@
 import hashlib
 import warnings
 from inspect import currentframe, getframeinfo
+from pathlib import Path
 from typing import Any, Mapping, Optional, Sequence
 
 import dace
Expand All @@ -26,7 +27,7 @@
import gt4py.next.program_processors.otf_compile_executor as otf_exec
import gt4py.next.program_processors.processor_interface as ppi
from gt4py.next import common
from gt4py.next.iterator import embedded as itir_embedded, transforms as itir_transforms
from gt4py.next.iterator import transforms as itir_transforms
from gt4py.next.otf.compilation import cache as compilation_cache
from gt4py.next.type_system import type_specifications as ts, type_translation

@@ -109,23 +110,29 @@ def _ensure_is_on_device(
 
 
 def get_connectivity_args(
-    neighbor_tables: Sequence[tuple[str, itir_embedded.NeighborTableOffsetProvider]],
+    neighbor_tables: Mapping[str, common.NeighborTable],
     device: dace.dtypes.DeviceType,
 ) -> dict[str, Any]:
     return {
-        connectivity_identifier(offset): _ensure_is_on_device(table.table, device)
-        for offset, table in neighbor_tables
+        connectivity_identifier(offset): _ensure_is_on_device(offset_provider.table, device)
+        for offset, offset_provider in neighbor_tables.items()
     }
 
 
 def get_shape_args(
     arrays: Mapping[str, dace.data.Array], args: Mapping[str, Any]
 ) -> Mapping[str, int]:
-    return {
-        str(sym): size
-        for name, value in args.items()
-        for sym, size in zip(arrays[name].shape, value.shape)
-    }
+    shape_args: dict[str, int] = {}
+    for name, value in args.items():
+        for sym, size in zip(arrays[name].shape, value.shape):
+            if isinstance(sym, dace.symbol):
+                assert sym.name not in shape_args
+                shape_args[sym.name] = size
+            elif sym != size:
+                raise RuntimeError(
+                    f"Expected shape {arrays[name].shape} for arg {name}, got {value.shape}."
+                )
+    return shape_args
 
 
 def get_offset_args(
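A hedged usage sketch of the reworked get_shape_args (array name and symbol invented): dace symbols in the array descriptor are mapped to runtime sizes, while integer-literal dimensions, such as a constant local size, are only validated and produce no entry:

import dace
import numpy as np

from gt4py.next.program_processors.runners.dace_iterator import get_shape_args

sdfg = dace.SDFG("example")
N = dace.symbol("N")
# The second dimension is a compile-time constant, e.g. max_neighbors == 2.
sdfg.add_array("A", shape=(N, 2), dtype=dace.int64)

shape_args = get_shape_args(sdfg.arrays, {"A": np.zeros((10, 2), dtype=np.int64)})
assert shape_args == {"N": 10}  # no shape symbol for the constant dimension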
@@ -158,34 +165,41 @@ def get_stride_args(
     return stride_args
 
 
-_build_cache_cpu: dict[str, CompiledSDFG] = {}
-_build_cache_gpu: dict[str, CompiledSDFG] = {}
+_build_cache: dict[str, CompiledSDFG] = {}
 
 
 def get_cache_id(
+    build_type: str,
+    build_for_gpu: bool,
     program: itir.FencilDefinition,
     arg_types: Sequence[ts.TypeSpec],
     column_axis: Optional[common.Dimension],
     offset_provider: Mapping[str, Any],
 ) -> str:
-    max_neighbors = [
-        (k, v.max_neighbors)
-        for k, v in offset_provider.items()
-        if isinstance(
-            v,
-            (
-                itir_embedded.NeighborTableOffsetProvider,
-                itir_embedded.StridedNeighborOffsetProvider,
-            ),
-        )
+    def offset_invariants(offset):
+        if isinstance(offset, common.Connectivity):
+            return (
+                offset.origin_axis,
+                offset.neighbor_axis,
+                offset.has_skip_values,
+                offset.max_neighbors,
+            )
+        if isinstance(offset, common.Dimension):
+            return (offset,)
+        return tuple()
+
+    offset_cache_keys = [
+        (name, *offset_invariants(offset)) for name, offset in offset_provider.items()
     ]
     cache_id_args = [
         str(arg)
         for arg in (
+            build_type,
+            build_for_gpu,
             program,
             *arg_types,
             column_axis,
-            *max_neighbors,
+            *offset_cache_keys,
         )
     ]
     m = hashlib.sha256()
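A standalone, simplified sketch of the hashing scheme (helper name invented; the real get_cache_id hashes the stringified arguments the same way): with the merged _build_cache above, the build flavor must enter the key so that CPU and GPU builds of the same program cannot shadow each other:

import hashlib

def make_cache_id(*parts) -> str:
    m = hashlib.sha256()
    for part in parts:
        m.update(str(part).encode())
    return m.hexdigest()

# Same program, different build flavor -> different cache entries.
assert make_cache_id("Debug", False, "program") != make_cache_id("Debug", True, "program")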
@@ -262,7 +276,7 @@ def build_sdfg_from_itir(
     # visit ITIR and generate SDFG
     program = preprocess_program(program, offset_provider, lift_mode)
     sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis)
-    sdfg = sdfg_genenerator.visit(program)
+    sdfg: dace.SDFG = sdfg_genenerator.visit(program)
     if sdfg is None:
         raise RuntimeError(f"Visit failed for program {program.id}.")
 
@@ -284,8 +298,8 @@
 
     # run DaCe auto-optimization heuristics
     if auto_optimize:
-        # TODO: Investigate how symbol definitions improve autoopt transformations,
-        # in which case the cache table should take the symbols map into account.
+        # TODO: Investigate performance improvement from SDFG specialization with constant symbols,
+        # for array shape and strides, although this would imply JIT compilation.
         symbols: dict[str, int] = {}
         device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU
         sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu)
@@ -307,25 +321,31 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
     # ITIR parameters
     column_axis = kwargs.get("column_axis", None)
     offset_provider = kwargs["offset_provider"]
+    # debug option to store SDFGs on filesystem and skip lowering ITIR to SDFG at each run
+    skip_itir_lowering_to_sdfg = kwargs.get("skip_itir_lowering_to_sdfg", False)
 
     arg_types = [type_translation.from_value(arg) for arg in args]
 
-    cache_id = get_cache_id(program, arg_types, column_axis, offset_provider)
+    cache_id = get_cache_id(build_type, on_gpu, program, arg_types, column_axis, offset_provider)
     if build_cache is not None and cache_id in build_cache:
         # retrieve SDFG program from build cache
         sdfg_program = build_cache[cache_id]
         sdfg = sdfg_program.sdfg
 
     else:
-        sdfg = build_sdfg_from_itir(
-            program,
-            *args,
-            offset_provider=offset_provider,
-            auto_optimize=auto_optimize,
-            on_gpu=on_gpu,
-            column_axis=column_axis,
-            lift_mode=lift_mode,
-        )
+        sdfg_filename = f"_dacegraphs/gt4py/{cache_id}/{program.id}.sdfg"
+        if not (skip_itir_lowering_to_sdfg and Path(sdfg_filename).exists()):
+            sdfg = build_sdfg_from_itir(
+                program,
+                *args,
+                offset_provider=offset_provider,
+                auto_optimize=auto_optimize,
+                on_gpu=on_gpu,
+                column_axis=column_axis,
+                lift_mode=lift_mode,
+            )
+            sdfg.save(sdfg_filename)
+        else:
+            sdfg = dace.SDFG.from_file(sdfg_filename)
 
     sdfg.build_folder = compilation_cache._session_cache_dir_path / ".dacecache"
     with dace.config.temporary_config():
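A standalone sketch of the save/reload mechanism behind the new skip_itir_lowering_to_sdfg debug option (toy SDFG and path invented):

from pathlib import Path

import dace

sdfg = dace.SDFG("toy")
sdfg.add_state("init")
filename = "_dacegraphs/example/toy.sdfg"

Path(filename).parent.mkdir(parents=True, exist_ok=True)
sdfg.save(filename)

# On a later run, skip lowering ITIR and load the stored graph instead.
if Path(filename).exists():
    sdfg = dace.SDFG.from_file(filename)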
@@ -361,7 +381,7 @@ def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
         program,
         *args,
         **kwargs,
-        build_cache=_build_cache_cpu,
+        build_cache=_build_cache,
         build_type=_build_type,
         compiler_args=compiler_args,
         on_gpu=False,
@@ -380,7 +400,7 @@ def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
         program,
         *args,
         **kwargs,
-        build_cache=_build_cache_gpu,
+        build_cache=_build_cache,
         build_type=_build_type,
         on_gpu=True,
     )
@@ -11,14 +11,14 @@
 # distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
-from typing import Any, Optional, cast
+from typing import Any, Mapping, Optional, Sequence, cast
 
 import dace
 
 import gt4py.eve as eve
 from gt4py.next import Dimension, DimensionKind, type_inference as next_typing
+from gt4py.next.common import NeighborTable
 from gt4py.next.iterator import ir as itir, type_inference as itir_typing
-from gt4py.next.iterator.embedded import NeighborTableOffsetProvider
 from gt4py.next.iterator.ir import Expr, FunCall, Literal, SymRef
 from gt4py.next.type_system import type_specifications as ts, type_translation
 
@@ -43,13 +43,12 @@
     flatten_list,
     get_sorted_dims,
     map_nested_sdfg_symbols,
-    new_array_symbols,
     unique_name,
     unique_var_name,
 )
 
 
-def get_scan_args(stencil: Expr) -> tuple[bool, Literal]:
+def _get_scan_args(stencil: Expr) -> tuple[bool, Literal]:
     """
     Parse stencil expression to extract the scan arguments.
@@ -68,7 +67,7 @@ def get_scan_args(stencil: Expr) -> tuple[bool, Literal]:
     return is_forward.value == "True", init_carry
 
 
-def get_scan_dim(
+def _get_scan_dim(
     column_axis: Dimension,
     storage_types: dict[str, ts.TypeSpec],
     output: SymRef,
@@ -93,6 +92,35 @@
     )
 
 
+def _make_array_shape_and_strides(
+    name: str,
+    dims: Sequence[Dimension],
+    neighbor_tables: Mapping[str, NeighborTable],
+    sort_dims: bool,
+) -> tuple[list[dace.symbol], list[dace.symbol]]:
+    """
+    Parse field dimensions and allocate symbols for array shape and strides.
+
+    For local dimensions, the size is known at compile-time and therefore
+    the corresponding array shape dimension is set to an integer literal value.
+
+    Returns
+    -------
+    tuple(shape, strides)
+        The output tuple fields are arrays of dace symbolic expressions.
+    """
+    dtype = dace.int64
+    sorted_dims = [dim for _, dim in get_sorted_dims(dims)] if sort_dims else dims
+    shape = [
+        neighbor_tables[dim.value].max_neighbors
+        if dim.kind == DimensionKind.LOCAL
+        else dace.symbol(unique_name(f"{name}_shape{i}"), dtype)
+        for i, dim in enumerate(sorted_dims)
+    ]
+    strides = [dace.symbol(unique_name(f"{name}_stride{i}"), dtype) for i, _ in enumerate(shape)]
+    return shape, strides
+
+
 class ItirToSDFG(eve.NodeVisitor):
     param_types: list[ts.TypeSpec]
     storage_types: dict[str, ts.TypeSpec]
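A hedged sketch of the new helper on a field with one global and one local dimension (dimension, offset, and field names invented; assumes _make_array_shape_and_strides is in scope): the Edge extent gets a fresh dace symbol, while the local E2V extent becomes the integer literal max_neighbors:

import numpy as np

from gt4py.next import Dimension, DimensionKind
from gt4py.next.iterator.embedded import NeighborTableOffsetProvider

Edge = Dimension("Edge")
Vertex = Dimension("Vertex")
E2V = Dimension("E2V", kind=DimensionKind.LOCAL)
e2v = NeighborTableOffsetProvider(np.zeros((5, 2), dtype=np.int64), Edge, Vertex, 2)

shape, strides = _make_array_shape_and_strides("field", [Edge, E2V], {"E2V": e2v}, False)
assert shape[1] == 2  # constant local size; shape[0] and all strides stay symbolic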
@@ -104,17 +132,27 @@ def __init__(
     def __init__(
         self,
         param_types: list[ts.TypeSpec],
-        offset_provider: dict[str, NeighborTableOffsetProvider],
+        offset_provider: dict[str, NeighborTable],
         column_axis: Optional[Dimension] = None,
     ):
         self.param_types = param_types
         self.column_axis = column_axis
         self.offset_provider = offset_provider
         self.storage_types = {}
 
-    def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset: bool = True):
+    def add_storage(
+        self,
+        sdfg: dace.SDFG,
+        name: str,
+        type_: ts.TypeSpec,
+        neighbor_tables: Mapping[str, NeighborTable],
+        has_offset: bool = True,
+        sort_dimensions: bool = True,
+    ):
         if isinstance(type_, ts.FieldType):
-            shape, strides = new_array_symbols(name, len(type_.dims))
+            shape, strides = _make_array_shape_and_strides(
+                name, type_.dims, neighbor_tables, sort_dimensions
+            )
             offset = (
                 [dace.symbol(unique_name(f"{name}_offset{i}_")) for i in range(len(type_.dims))]
                 if has_offset
@@ -153,14 +191,23 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
 
         # Add program parameters as SDFG storages.
         for param, type_ in zip(node.params, self.param_types):
-            self.add_storage(program_sdfg, str(param.id), type_)
+            self.add_storage(program_sdfg, str(param.id), type_, neighbor_tables)
 
         # Add connectivities as SDFG storages.
-        for offset, table in neighbor_tables:
-            scalar_kind = type_translation.get_scalar_kind(table.table.dtype)
-            local_dim = Dimension("ElementDim", kind=DimensionKind.LOCAL)
-            type_ = ts.FieldType([table.origin_axis, local_dim], ts.ScalarType(scalar_kind))
-            self.add_storage(program_sdfg, connectivity_identifier(offset), type_, has_offset=False)
+        for offset, offset_provider in neighbor_tables.items():
+            scalar_kind = type_translation.get_scalar_kind(offset_provider.table.dtype)
+            local_dim = Dimension(offset, kind=DimensionKind.LOCAL)
+            type_ = ts.FieldType(
+                [offset_provider.origin_axis, local_dim], ts.ScalarType(scalar_kind)
+            )
+            self.add_storage(
+                program_sdfg,
+                connectivity_identifier(offset),
+                type_,
+                neighbor_tables,
+                has_offset=False,
+                sort_dimensions=False,
+            )
 
         # Create a nested SDFG for all stencil closures.
         for closure in node.closures:
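A small hedged illustration (offset name invented) of why the placeholder "ElementDim" is replaced by the offset name above: naming the connectivity's local dimension after its offset is what lets _make_array_shape_and_strides find the matching table via neighbor_tables[dim.value]:

from gt4py.next import Dimension, DimensionKind

# Before: every connectivity shared one placeholder local dimension.
old_local_dim = Dimension("ElementDim", kind=DimensionKind.LOCAL)
# After: the local dimension carries the offset name, usable as a lookup key.
new_local_dim = Dimension("E2V", kind=DimensionKind.LOCAL)
assert new_local_dim.value == "E2V"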
@@ -222,7 +269,7 @@ def visit_StencilClosure(
 
         input_names = [str(inp.id) for inp in node.inputs]
         neighbor_tables = filter_neighbor_tables(self.offset_provider)
-        connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables]
+        connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()]
 
         output_nodes = self.get_output_nodes(node, closure_sdfg, closure_state)
         output_names = [k for k, _ in output_nodes.items()]
@@ -400,11 +447,11 @@ def _visit_scan_stencil_closure(
         output_name: str,
     ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], int]:
         # extract scan arguments
-        is_forward, init_carry_value = get_scan_args(node.stencil)
+        is_forward, init_carry_value = _get_scan_args(node.stencil)
         # select the scan dimension based on program argument for column axis
         assert self.column_axis
         assert isinstance(node.output, SymRef)
-        scan_dim, scan_dim_index, scan_dtype = get_scan_dim(
+        scan_dim, scan_dim_index, scan_dtype = _get_scan_dim(
             self.column_axis,
             self.storage_types,
             node.output,
@@ -570,7 +617,7 @@ def _visit_parallel_stencil_closure(
     ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], list[str]]:
         neighbor_tables = filter_neighbor_tables(self.offset_provider)
         input_names = [str(inp.id) for inp in node.inputs]
-        conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables]
+        connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()]
 
         # find the scan dimension, same as output dimension, and exclude it from the map domain
         map_ranges = {}
@@ -583,7 +630,7 @@
         index_domain = {dim: f"i_{dim}" for dim, _ in closure_domain}
 
         input_arrays = [(name, self.storage_types[name]) for name in input_names]
-        connectivity_arrays = [(array_table[name], name) for name in conn_names]
+        connectivity_arrays = [(array_table[name], name) for name in connectivity_names]
 
         context, results = closure_to_tasklet_sdfg(
             node,