Skip to content

Commit

Permalink
feat[next][dace]: GTIR-to-DaCe lowering of map-reduce (only full connectivity) (#1683)
Browse files Browse the repository at this point in the history

This PR adds support for lowering of `map_` and `make_const_list`
builtin functions. However, the current implementation only supports
neighbor tables with full connectivity (no skip values). The support for
skip values will be added in a follow-up PR.

To be noted:
- This PR generalizes the handling of tasklets without arguments inside
a map scope. The return type for `input_connections` is extended to
contain a `TaskletConnection` variant, which is lowered to an empty edge
from the map entry node to the tasklet node.
- The result of `make_const_list` is a scalar value to be broadcasted on
a local field. However, in order to keep the lowering simple, this value
is represented as a 1D 1-element array (`shape=(1,)`).
  • Loading branch information
edopao authored Oct 17, 2024
1 parent 3f7fcee commit 0a27c7a
Show file tree
Hide file tree
Showing 9 changed files with 453 additions and 234 deletions.
10 changes: 10 additions & 0 deletions src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]:
)


def is_applied_map(arg: itir.Node) -> TypeGuard[itir.FunCall]:
    """Match expressions of the form `map(λ(...) → ...)(...)`."""
    # An applied map is a call whose callee is itself a call to the
    # `map_` builtin, i.e. `map_(op)(args...)`.
    if not isinstance(arg, itir.FunCall):
        return False
    inner = arg.fun
    if not isinstance(inner, itir.FunCall):
        return False
    return isinstance(inner.fun, itir.SymRef) and inner.fun.id == "map_"


def is_applied_reduce(arg: itir.Node) -> TypeGuard[itir.FunCall]:
"""Match expressions of the form `reduce(λ(...) → ...)(...)`."""
return (
Expand Down
19 changes: 5 additions & 14 deletions src/gt4py/next/iterator/transforms/fuse_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
# SPDX-License-Identifier: BSD-3-Clause

import dataclasses
from typing import TypeGuard

from gt4py.eve import NodeTranslator, traits
from gt4py.eve.utils import UIDGenerator
Expand All @@ -16,14 +15,6 @@
from gt4py.next.iterator.transforms import inline_lambdas


def _is_map(node: ir.Node) -> TypeGuard[ir.FunCall]:
return (
isinstance(node, ir.FunCall)
and isinstance(node.fun, ir.FunCall)
and node.fun.fun == ir.SymRef(id="map_")
)


@dataclasses.dataclass(frozen=True)
class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator):
"""
Expand Down Expand Up @@ -58,10 +49,10 @@ def _as_lambda(self, fun: ir.SymRef | ir.Lambda, param_count: int) -> ir.Lambda:

def visit_FunCall(self, node: ir.FunCall, **kwargs):
node = self.generic_visit(node)
if _is_map(node) or cpm.is_applied_reduce(node):
if any(_is_map(arg) for arg in node.args):
if cpm.is_applied_map(node) or cpm.is_applied_reduce(node):
if any(cpm.is_applied_map(arg) for arg in node.args):
first_param = (
0 if _is_map(node) else 1
0 if cpm.is_applied_map(node) else 1
) # index of the first param of op that maps to args (0 for map, 1 for reduce)
assert isinstance(node.fun, ir.FunCall)
assert isinstance(node.fun.args[0], (ir.Lambda, ir.SymRef))
Expand All @@ -76,7 +67,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs):
new_params.append(outer_op.params[0])

for i in range(len(node.args)):
if _is_map(node.args[i]):
if cpm.is_applied_map(node.args[i]):
map_call = node.args[i]
assert isinstance(map_call, ir.FunCall)
assert isinstance(map_call.fun, ir.FunCall)
Expand All @@ -102,7 +93,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs):
new_body
) # removes one level of nesting (the recursive inliner could simplify more, however this can also be done on the full tree later)
new_op = ir.Lambda(params=new_params, expr=new_body)
if _is_map(node):
if cpm.is_applied_map(node):
return ir.FunCall(
fun=ir.FunCall(fun=ir.SymRef(id="map_"), args=[new_op]), args=new_args
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,13 @@ def as_dace_type(type_: ts.ScalarType) -> dace.typeclass:
raise ValueError(f"Scalar type '{type_}' not supported.")


def as_scalar_type(typestr: str) -> ts.ScalarType:
"""Obtain GT4Py scalar type from generic numpy string representation."""
def as_itir_type(dtype: dace.typeclass) -> ts.ScalarType:
"""Get GT4Py scalar representation of a DaCe type."""
type_name = str(dtype.as_numpy_dtype())
try:
kind = getattr(ts.ScalarKind, typestr.upper())
kind = getattr(ts.ScalarKind, type_name.upper())
except AttributeError as ex:
raise ValueError(f"Data type {typestr} not supported.") from ex
raise ValueError(f"Data type {type_name} not supported.") from ex
return ts.ScalarType(kind)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@

import abc
import dataclasses
from typing import TYPE_CHECKING, Iterable, Optional, Protocol, TypeAlias
from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, TypeAlias

import dace
import dace.subsets as sbs

from gt4py.next import common as gtx_common, utils as gtx_utils
from gt4py.next.ffront import fbuiltins as gtx_fbuiltins
from gt4py.next.iterator import ir as gtir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.iterator.type_system import type_specifications as itir_ts
Expand All @@ -32,16 +33,29 @@
from gt4py.next.program_processors.runners.dace_fieldview import gtir_sdfg


IteratorIndexDType: TypeAlias = dace.int32 # type of iterator indexes


@dataclasses.dataclass(frozen=True)
class Field:
    """Pairs a DaCe access node with the GT4Py type of the data it holds."""

    # Access node referring to the data container in the SDFG.
    data_node: dace.nodes.AccessNode
    # GT4Py type descriptor of the data: a field type or a scalar type.
    data_type: ts.FieldType | ts.ScalarType


FieldopDomain: TypeAlias = list[
tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]
]
"""
Domain of a field operator represented as a list of tuples with 3 elements:
- dimension definition
- symbolic expression for lower bound (inclusive)
- symbolic expression for upper bound (exclusive)
"""


FieldopResult: TypeAlias = Field | tuple[Field | tuple, ...]
"""Result of a field operator, can be either a field or a tuple fields."""


INDEX_DTYPE: Final[dace.typeclass] = dace.dtype_to_typeclass(gtx_fbuiltins.IndexType)
"""Data type used for field indexing."""


class PrimitiveTranslator(Protocol):
Expand Down Expand Up @@ -81,11 +95,11 @@ def _parse_fieldop_arg(
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_sdfg.SDFGBuilder,
domain: list[
tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]
],
domain: FieldopDomain,
reduce_identity: Optional[gtir_dataflow.SymbolExpr],
) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr:
"""Helper method to visit an expression passed as argument to a field operator."""

arg = sdfg_builder.visit(
node,
sdfg=sdfg,
Expand All @@ -101,10 +115,7 @@ def _parse_fieldop_arg(
return gtir_dataflow.MemletExpr(arg.data_node, sbs.Indices([0]))
elif isinstance(arg.data_type, ts.FieldType):
indices: dict[gtx_common.Dimension, gtir_dataflow.ValueExpr] = {
dim: gtir_dataflow.SymbolExpr(
dace_gtir_utils.get_map_variable(dim),
IteratorIndexDType,
)
dim: gtir_dataflow.SymbolExpr(dace_gtir_utils.get_map_variable(dim), INDEX_DTYPE)
for dim, _, _ in domain
}
dims = arg.data_type.dims + (
Expand All @@ -120,12 +131,11 @@ def _parse_fieldop_arg(
def _create_temporary_field(
sdfg: dace.SDFG,
state: dace.SDFGState,
domain: list[
tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]
],
domain: FieldopDomain,
node_type: ts.FieldType,
output_desc: dace.data.Data,
dataflow_output: gtir_dataflow.DataflowOutputEdge,
) -> Field:
"""Helper method to allocate a temporary field where to write the output of a field operator."""
domain_dims, _, domain_ubs = zip(*domain)
field_dims = list(domain_dims)
# It should be enough to allocate an array with shape (upper_bound - lower_bound)
Expand All @@ -138,6 +148,7 @@ def _create_temporary_field(
# eliminate most of transient arrays.
field_shape = list(domain_ubs)

output_desc = dataflow_output.result.node.desc(sdfg)
if isinstance(output_desc, dace.data.Array):
assert isinstance(node_type.dtype, itir_ts.ListType)
assert isinstance(node_type.dtype.element_type, ts.ScalarType)
Expand All @@ -157,7 +168,31 @@ def _create_temporary_field(
return Field(field_node, field_type)


def translate_as_field_op(
def extract_domain(node: gtir.Node) -> FieldopDomain:
    """
    Visits the domain of a field operator and returns a list of dimensions and
    the corresponding lower and upper bounds. The returned lower bound is inclusive,
    the upper bound is exclusive: [lower_bound, upper_bound[

    Args:
        node: GTIR call to either `cartesian_domain` or `unstructured_domain`,
            whose arguments are `named_range` calls, one per dimension.

    Returns:
        One `(dimension, lower_bound, upper_bound)` tuple per `named_range`
        argument, in the order they appear in the domain expression.
    """
    assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain"))

    domain = []
    for named_range in node.args:
        # Each axis is described by a call `named_range(axis, lower, upper)`.
        assert cpm.is_call_to(named_range, "named_range")
        assert len(named_range.args) == 3
        axis = named_range.args[0]
        assert isinstance(axis, gtir.AxisLiteral)
        # The bounds are GTIR expressions: render them to Python source text,
        # then parse that text into DaCe symbolic expressions.
        lower_bound, upper_bound = (
            dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(arg))
            for arg in named_range.args[1:3]
        )
        dim = gtx_common.Dimension(axis.value, axis.kind)
        domain.append((dim, lower_bound, upper_bound))

    return domain


def translate_as_fieldop(
node: gtir.Node,
sdfg: dace.SDFG,
state: dace.SDFGState,
Expand Down Expand Up @@ -188,25 +223,55 @@ def translate_as_field_op(
assert isinstance(domain_expr, gtir.FunCall)

# parse the domain of the field operator
domain = dace_gtir_utils.get_domain(domain_expr)
domain = extract_domain(domain_expr)

# The reduction identity value is used in place of skip values when building
# a list of neighbor values in the unstructured domain.
#
# A reduction on neighbor values can be either expressed in local view (itir):
# vertices @ u⟨ Vertexₕ: [0, nvertices) ⟩
# ← as_fieldop(
# λ(it) → reduce(plus, 0)(neighbors(V2Eₒ, it)), u⟨ Vertexₕ: [0, nvertices) ⟩
# )(edges);
#
# or in field view (gtir):
# vertices @ u⟨ Vertexₕ: [0, nvertices) ⟩
# ← as_fieldop(λ(it) → reduce(plus, 0)(·it), u⟨ Vertexₕ: [0, nvertices) ⟩)(
# as_fieldop(λ(it) → neighbors(V2Eₒ, it), u⟨ Vertexₕ: [0, nvertices) ⟩)(edges)
# );
#
# In local view, the list of neighbors is (recursively) built while visiting
# the current expression.
# In field view, the list of neighbors is built as argument to the current
# expression. Therefore, the reduction identity value needs to be passed to
# the argument visitor (`reduce_identity_for_args = reduce_identity`).
if cpm.is_applied_reduce(stencil_expr.expr):
if reduce_identity is not None:
raise NotImplementedError("nested reductions not supported.")

# the reduce identity value is used to fill the skip values in neighbors list
_, _, reduce_identity = gtir_dataflow.get_reduce_params(stencil_expr.expr)
raise NotImplementedError("Nested reductions are not supported.")
_, _, reduce_identity_for_args = gtir_dataflow.get_reduce_params(stencil_expr.expr)
elif cpm.is_call_to(stencil_expr.expr, "neighbors"):
# When the visitor hits a neighbors expression, we stop carrying the reduce
# identity further (`reduce_identity_for_args = None`) because the reduce
# identity value is filled in place of skip values in the context of neighbors
# itself, not in the arguments context.
# Besides, setting `reduce_identity_for_args = None` enables a sanity check
# that the sequence 'reduce(V2E) -> neighbors(V2E) -> reduce(C2E) -> neighbors(C2E)'
# is accepted, while 'reduce(V2E) -> reduce(C2E) -> neighbors(V2E) -> neighbors(C2E)'
# is not. The latter sequence would raise the 'NotImplementedError' exception above.
reduce_identity_for_args = None
else:
reduce_identity_for_args = reduce_identity

# visit the list of arguments to be passed to the lambda expression
stencil_args = [
_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain, reduce_identity)
_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain, reduce_identity_for_args)
for arg in node.args
]

# represent the field operator as a mapped tasklet graph, which will range over the field domain
taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder, reduce_identity)
input_edges, output = taskgen.visit(stencil_expr, args=stencil_args)
output_desc = output.expr.node.desc(sdfg)
output_desc = output.result.node.desc(sdfg)

domain_index = sbs.Indices([dace_gtir_utils.get_map_variable(dim) for dim, _, _ in domain])
if isinstance(node.type.dtype, itir_ts.ListType):
Expand All @@ -220,11 +285,17 @@ def translate_as_field_op(
output_subset = sbs.Range.from_indices(domain_index)

# create map range corresponding to the field operator domain
map_ranges = {dace_gtir_utils.get_map_variable(dim): f"{lb}:{ub}" for dim, lb, ub in domain}
me, mx = sdfg_builder.add_map("field_op", state, map_ranges)
me, mx = sdfg_builder.add_map(
"fieldop",
state,
ndrange={
dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}"
for dim, lower_bound, upper_bound in domain
},
)

# allocate local temporary storage for the result field
result_field = _create_temporary_field(sdfg, state, domain, node.type, output_desc)
result_field = _create_temporary_field(sdfg, state, domain, node.type, output)

# here we setup the edges from the map entry node
for edge in input_edges:
Expand Down Expand Up @@ -439,7 +510,7 @@ def translate_tuple_get(

if not isinstance(node.args[0], gtir.Literal):
raise ValueError("Tuple can only be subscripted with compile-time constants.")
assert node.args[0].type == dace_utils.as_scalar_type(gtir.INTEGER_INDEX_BUILTIN)
assert node.args[0].type == dace_utils.as_itir_type(INDEX_DTYPE)
index = int(node.args[0].value)

data_nodes = sdfg_builder.visit(
Expand Down Expand Up @@ -566,7 +637,7 @@ def translate_symbol_ref(
if TYPE_CHECKING:
# Use type-checking to assert that all translator functions implement the `PrimitiveTranslator` protocol
__primitive_translators: list[PrimitiveTranslator] = [
translate_as_field_op,
translate_as_fieldop,
translate_if,
translate_literal,
translate_make_tuple,
Expand Down
Loading

0 comments on commit 0a27c7a

Please sign in to comment.