Skip to content

Commit

Permalink
Review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Jan 25, 2024
1 parent 9f16cdd commit de4de8e
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import hashlib
import warnings
from inspect import currentframe, getframeinfo
from pathlib import Path
from typing import Any, Mapping, Optional, Sequence

import dace
Expand Down Expand Up @@ -329,12 +330,12 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
sdfg = sdfg_program.sdfg

else:
# debug: test large icon4py stencils without regenerating the SDFG at each run
generate_sdfg = True
# useful for debug: run gt4py program without regenerating the SDFG at each run
skip_itir_lowering_to_sdfg = False
target = "gpu" if on_gpu else "cpu"
sdfg_filename = f"_dacegraphs/{target}/{program.id}.sdfg"
sdfg_filename = f"_dacegraphs/gt4py/{target}/{program.id}.sdfg"

if generate_sdfg:
if not (skip_itir_lowering_to_sdfg and Path(sdfg_filename).exists()):
sdfg = build_sdfg_from_itir(
program,
*args,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later
from typing import Any, Mapping, Optional, cast
from typing import Any, Mapping, Optional, Sequence, cast

import dace

Expand Down Expand Up @@ -42,14 +42,13 @@
filter_neighbor_tables,
flatten_list,
get_sorted_dims,
map_field_dimensions_to_sdfg_symbols,
map_nested_sdfg_symbols,
unique_name,
unique_var_name,
)


def get_scan_args(stencil: Expr) -> tuple[bool, Literal]:
def _get_scan_args(stencil: Expr) -> tuple[bool, Literal]:
"""
Parse stencil expression to extract the scan arguments.
Expand All @@ -68,7 +67,7 @@ def get_scan_args(stencil: Expr) -> tuple[bool, Literal]:
return is_forward.value == "True", init_carry


def get_scan_dim(
def _get_scan_dim(
column_axis: Dimension,
storage_types: dict[str, ts.TypeSpec],
output: SymRef,
Expand All @@ -93,6 +92,35 @@ def get_scan_dim(
)


def _make_array_shape_and_strides(
name: str,
dims: Sequence[Dimension],
neighbor_tables: Mapping[str, NeighborTable],
sort_dims: bool,
) -> tuple[list[dace.symbol], list[dace.symbol]]:
"""
Parse field dimensions and allocate symbols for array shape and strides.
For local dimensions, the size is known at compile-time and therefore
the corresponding array shape dimension is set to an integer literal value.
Returns
-------
tuple(shape, strides)
The output tuple fields are arrays of dace symbolic expressions.
"""
dtype = dace.int64
sorted_dims = [dim for _, dim in get_sorted_dims(dims)] if sort_dims else dims
shape = [
neighbor_tables[dim.value].max_neighbors
if dim.kind == DimensionKind.LOCAL
else dace.symbol(unique_name(f"{name}_shape{i}"), dtype)
for i, dim in enumerate(sorted_dims)
]
strides = [dace.symbol(unique_name(f"{name}_stride{i}"), dtype) for i, _ in enumerate(shape)]
return shape, strides


class ItirToSDFG(eve.NodeVisitor):
param_types: list[ts.TypeSpec]
storage_types: dict[str, ts.TypeSpec]
Expand Down Expand Up @@ -122,7 +150,7 @@ def add_storage(
sort_dimensions: bool = True,
):
if isinstance(type_, ts.FieldType):
shape, strides = map_field_dimensions_to_sdfg_symbols(
shape, strides = _make_array_shape_and_strides(
name, type_.dims, neighbor_tables, sort_dimensions
)
offset = (
Expand Down Expand Up @@ -166,10 +194,12 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
self.add_storage(program_sdfg, str(param.id), type_, neighbor_tables)

# Add connectivities as SDFG storages.
for offset, table in neighbor_tables.items():
scalar_kind = type_translation.get_scalar_kind(table.table.dtype)
for offset, offset_provider in neighbor_tables.items():
scalar_kind = type_translation.get_scalar_kind(offset_provider.table.dtype)
local_dim = Dimension(offset, kind=DimensionKind.LOCAL)
type_ = ts.FieldType([table.origin_axis, local_dim], ts.ScalarType(scalar_kind))
type_ = ts.FieldType(
[offset_provider.origin_axis, local_dim], ts.ScalarType(scalar_kind)
)
self.add_storage(
program_sdfg,
connectivity_identifier(offset),
Expand Down Expand Up @@ -239,9 +269,7 @@ def visit_StencilClosure(

input_names = [str(inp.id) for inp in node.inputs]
neighbor_tables = filter_neighbor_tables(self.offset_provider)
connectivity_names = [
connectivity_identifier(offset) for offset, _ in neighbor_tables.items()
]
connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()]

output_nodes = self.get_output_nodes(node, closure_sdfg, closure_state)
output_names = [k for k, _ in output_nodes.items()]
Expand Down Expand Up @@ -419,11 +447,11 @@ def _visit_scan_stencil_closure(
output_name: str,
) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], int]:
# extract scan arguments
is_forward, init_carry_value = get_scan_args(node.stencil)
is_forward, init_carry_value = _get_scan_args(node.stencil)
# select the scan dimension based on program argument for column axis
assert self.column_axis
assert isinstance(node.output, SymRef)
scan_dim, scan_dim_index, scan_dtype = get_scan_dim(
scan_dim, scan_dim_index, scan_dtype = _get_scan_dim(
self.column_axis,
self.storage_types,
node.output,
Expand Down Expand Up @@ -589,9 +617,7 @@ def _visit_parallel_stencil_closure(
) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], list[str]]:
neighbor_tables = filter_neighbor_tables(self.offset_provider)
input_names = [str(inp.id) for inp in node.inputs]
connectivity_names = [
connectivity_identifier(offset) for offset, _ in neighbor_tables.items()
]
connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()]

# find the scan dimension, same as output dimension, and exclude it from the map domain
map_ranges = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,11 +382,9 @@ def builtin_list_get(
result_name = unique_var_name()
transformer.context.body.add_scalar(result_name, args[1].dtype, transient=True)
result_node = transformer.context.state.add_access(result_name)
transformer.context.state.add_edge(
transformer.context.state.add_nedge(
args[1].value,
None,
result_node,
None,
dace.Memlet.simple(args[1].value.data, index_value),
)
return [ValueExpr(result_node, args[1].dtype)]
Expand Down Expand Up @@ -576,9 +574,7 @@ def visit_Lambda(
neighbor_tables = (
filter_neighbor_tables(self.offset_provider) if use_neighbor_tables else {}
)
connectivity_names = [
connectivity_identifier(offset) for offset, _ in neighbor_tables.items()
]
connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()]

# Create the SDFG for the lambda's body
lambda_sdfg = dace.SDFG(func_name)
Expand Down Expand Up @@ -714,8 +710,8 @@ def _visit_call(self, node: itir.FunCall):
nsdfg_inputs[var] = create_memlet_full(store, self.context.body.arrays[store])

neighbor_tables = filter_neighbor_tables(self.offset_provider)
for conn, _ in neighbor_tables.items():
var = connectivity_identifier(conn)
for offset in neighbor_tables.keys():
var = connectivity_identifier(offset)
nsdfg_inputs[var] = create_memlet_full(var, self.context.body.arrays[var])

symbol_mapping = map_nested_sdfg_symbols(self.context.body, func_context.body, nsdfg_inputs)
Expand Down Expand Up @@ -743,8 +739,8 @@ def _visit_call(self, node: itir.FunCall):
store = value.indices[dim]
idx_memlet = nsdfg_inputs[var]
self.context.state.add_edge(store, None, nsdfg_node, var, idx_memlet)
for conn, _ in neighbor_tables.items():
var = connectivity_identifier(conn)
for offset in neighbor_tables.keys():
var = connectivity_identifier(offset)
memlet = nsdfg_inputs[var]
access = self.context.state.add_access(var, debuginfo=nsdfg_node.debuginfo)
self.context.state.add_edge(access, None, nsdfg_node, var, memlet)
Expand Down
22 changes: 2 additions & 20 deletions src/gt4py/next/program_processors/runners/dace_iterator/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import itertools
from typing import Any, Mapping, Optional, Sequence
from typing import Any, Optional, Sequence

import dace

from gt4py.next import Dimension, DimensionKind
from gt4py.next import Dimension
from gt4py.next.common import NeighborTable
from gt4py.next.iterator.ir import Node
from gt4py.next.type_system import type_specifications as ts
Expand Down Expand Up @@ -181,24 +181,6 @@ def unique_var_name():
return unique_name("_var")


def map_field_dimensions_to_sdfg_symbols(
name: str,
dims: Sequence[Dimension],
neighbor_tables: Mapping[str, NeighborTable],
sort_dims: bool,
) -> tuple[list[dace.symbol], list[dace.symbol]]:
dtype = dace.int64
sorted_dims = [dim for _, dim in get_sorted_dims(dims)] if sort_dims else dims
shape = [
neighbor_tables[dim.value].max_neighbors
if dim.kind == DimensionKind.LOCAL
else dace.symbol(unique_name(f"{name}_shape{i}"), dtype)
for i, dim in enumerate(sorted_dims)
]
strides = [dace.symbol(unique_name(f"{name}_stride{i}"), dtype) for i, _ in enumerate(shape)]
return shape, strides


def new_array_symbols(name: str, ndim: int) -> tuple[list[dace.symbol], list[dace.symbol]]:
dtype = dace.int64
shape = [dace.symbol(unique_name(f"{name}_shape{i}"), dtype) for i in range(ndim)]
Expand Down

0 comments on commit de4de8e

Please sign in to comment.