Skip to content

Commit

Permalink
feat[next][dace]: Canonicalize GTIR to enable lowering to SDFG for sp…
Browse files Browse the repository at this point in the history
…ecial cases (#1681)

Introduce a utility function `patch_gtir` that makes the IR compliant
with the requirements of lowering to SDFG:
- Add support for `as_fieldop` expressions that broadcast a scalar value
on a field.
- Add support for GTIR workaround in domain inference, where the domain
is not inferred on the tuple fields that are not referenced.

Additionally, apply some refactoring:
- Use `im.domain` to construct the domain in gtir-to-sdfg tests.
- Rename `ValueExpr` to `DataExpr` and `IteratorIndexExpr` to `ValueExpr`
  • Loading branch information
edopao authored Oct 8, 2024
1 parent 3b6261f commit 9bb6c94
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 182 deletions.
4 changes: 4 additions & 0 deletions src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,7 @@ def is_call_to(node: itir.Node, fun: str | Iterable[str]) -> TypeGuard[itir.FunC
return (
isinstance(node, itir.FunCall) and isinstance(node.fun, itir.SymRef) and node.fun.id == fun
)


def is_ref_to(node, ref: str):
return isinstance(node, itir.SymRef) and node.id == ref
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _parse_fieldop_arg(
if isinstance(arg.data_type, ts.ScalarType):
return gtir_to_tasklet.MemletExpr(arg.data_node, sbs.Indices([0]))
elif isinstance(arg.data_type, ts.FieldType):
indices: dict[gtx_common.Dimension, gtir_to_tasklet.IteratorIndexExpr] = {
indices: dict[gtx_common.Dimension, gtir_to_tasklet.ValueExpr] = {
dim: gtir_to_tasklet.SymbolExpr(
dace_gtir_utils.get_map_variable(dim),
IteratorIndexDType,
Expand Down Expand Up @@ -179,7 +179,6 @@ def translate_as_field_op(

# add local storage to compute the field operator over the given domain
domain = dace_gtir_utils.get_domain(domain_expr)
assert isinstance(node.type, ts.FieldType)

if cpm.is_applied_reduce(stencil_expr.expr):
if reduce_identity is not None:
Expand All @@ -188,7 +187,7 @@ def translate_as_field_op(
# the reduce identity value is used to fill the skip values in neighbors list
_, _, reduce_identity = gtir_to_tasklet.get_reduce_params(stencil_expr.expr)

# first visit the list of arguments and build a symbol map
# visit the list of arguments to be passed to the lambda expression
stencil_args = [
_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain, reduce_identity)
for arg in node.args
Expand All @@ -197,7 +196,7 @@ def translate_as_field_op(
# represent the field operator as a mapped tasklet graph, which will range over the field domain
taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder, reduce_identity)
input_connections, output_expr = taskgen.visit(stencil_expr, args=stencil_args)
assert isinstance(output_expr, gtir_to_tasklet.ValueExpr)
assert isinstance(output_expr, gtir_to_tasklet.DataExpr)
output_desc = output_expr.node.desc(sdfg)

# retrieve the tasklet node which writes the result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ def visit_SymRef(


def build_sdfg_from_gtir(
program: gtir.Program,
ir: gtir.Program,
offset_provider: gtx_common.OffsetProvider,
) -> dace.SDFG:
"""
Expand All @@ -653,15 +653,16 @@ def build_sdfg_from_gtir(
As a final step, it runs the `simplify` pass to ensure that the SDFG is in the DaCe canonical form.
Arguments:
program: The GTIR program node to be lowered to SDFG
ir: The GTIR program node to be lowered to SDFG
offset_provider: The definitions of offset providers used by the program node
Returns:
An SDFG in the DaCe canonical form (simplified)
"""
program = gtir_type_inference.infer(program, offset_provider=offset_provider)
ir = gtir_type_inference.infer(ir, offset_provider=offset_provider)
ir = dace_gtir_utils.patch_gtir(ir)
sdfg_genenerator = GTIRToSDFG(offset_provider)
sdfg = sdfg_genenerator.visit(program)
sdfg = sdfg_genenerator.visit(ir)
assert isinstance(sdfg, dace.SDFG)

gtx_transformations.gt_simplify(sdfg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,17 @@
from gt4py.next.type_system import type_specifications as ts


@dataclasses.dataclass(frozen=True)
class DataExpr:
"""Local storage for the computation result returned by a tasklet node."""

node: dace.nodes.AccessNode
dtype: gtir_ts.ListType | ts.ScalarType


@dataclasses.dataclass(frozen=True)
class MemletExpr:
"""Scalar or array data access thorugh a memlet."""
"""Scalar or array data access through a memlet."""

node: dace.nodes.AccessNode
subset: sbs.Indices | sbs.Range
Expand All @@ -44,14 +52,6 @@ class SymbolExpr:
dtype: dace.typeclass


@dataclasses.dataclass(frozen=True)
class ValueExpr:
"""Result of the computation implemented by a tasklet node."""

node: dace.nodes.AccessNode
dtype: gtir_ts.ListType | ts.ScalarType


# Define alias for the elements needed to setup input connections to a map scope
InputConnection: TypeAlias = tuple[
dace.nodes.AccessNode,
Expand All @@ -60,7 +60,7 @@ class ValueExpr:
Optional[str],
]

IteratorIndexExpr: TypeAlias = MemletExpr | SymbolExpr | ValueExpr
ValueExpr: TypeAlias = DataExpr | MemletExpr | SymbolExpr


@dataclasses.dataclass(frozen=True)
Expand All @@ -80,7 +80,7 @@ class IteratorExpr:

field: dace.nodes.AccessNode
dimensions: list[gtx_common.Dimension]
indices: dict[gtx_common.Dimension, IteratorIndexExpr]
indices: dict[gtx_common.Dimension, ValueExpr]


DACE_REDUCTION_MAPPING: dict[str, dace.dtypes.ReductionType] = {
Expand Down Expand Up @@ -193,7 +193,7 @@ def _get_tasklet_result(
dtype: dace.typeclass,
src_node: dace.nodes.Tasklet,
src_connector: str,
) -> ValueExpr:
) -> DataExpr:
temp_name = self.sdfg.temp_data_name()
self.sdfg.add_scalar(temp_name, dtype, transient=True)
data_type = dace_utils.as_scalar_type(str(dtype.as_numpy_dtype()))
Expand All @@ -205,9 +205,9 @@ def _get_tasklet_result(
None,
dace.Memlet(data=temp_name, subset="0"),
)
return ValueExpr(temp_node, data_type)
return DataExpr(temp_node, data_type)

def _visit_deref(self, node: gtir.FunCall) -> MemletExpr | ValueExpr:
def _visit_deref(self, node: gtir.FunCall) -> ValueExpr:
"""
Visit a `deref` node, which represents dereferencing of an iterator.
The iterator is the argument of this node.
Expand All @@ -226,26 +226,26 @@ def _visit_deref(self, node: gtir.FunCall) -> MemletExpr | ValueExpr:
IndexConnectorFmt: Final = "__index_{dim}"

assert len(node.args) == 1
it = self.visit(node.args[0])
arg_expr = self.visit(node.args[0])

if isinstance(it, IteratorExpr):
field_desc = it.field.desc(self.sdfg)
assert len(field_desc.shape) == len(it.dimensions)
if all(isinstance(index, SymbolExpr) for index in it.indices.values()):
if isinstance(arg_expr, IteratorExpr):
field_desc = arg_expr.field.desc(self.sdfg)
assert len(field_desc.shape) == len(arg_expr.dimensions)
if all(isinstance(index, SymbolExpr) for index in arg_expr.indices.values()):
# when all indices are symblic expressions, we can perform direct field access through a memlet
field_subset = sbs.Range(
(it.indices[dim].value, it.indices[dim].value, 1) # type: ignore[union-attr]
if dim in it.indices
(arg_expr.indices[dim].value, arg_expr.indices[dim].value, 1) # type: ignore[union-attr]
if dim in arg_expr.indices
else (0, size - 1, 1)
for dim, size in zip(it.dimensions, field_desc.shape)
for dim, size in zip(arg_expr.dimensions, field_desc.shape)
)
return MemletExpr(it.field, field_subset)
return MemletExpr(arg_expr.field, field_subset)

else:
# we use a tasklet to dereference an iterator when one or more indices are the result of some computation,
# either indirection through connectivity table or dynamic cartesian offset.
assert all(dim in it.indices for dim in it.dimensions)
field_indices = [(dim, it.indices[dim]) for dim in it.dimensions]
assert all(dim in arg_expr.indices for dim in arg_expr.dimensions)
field_indices = [(dim, arg_expr.indices[dim]) for dim in arg_expr.dimensions]
index_connectors = [
IndexConnectorFmt.format(dim=dim.value)
for dim, index in field_indices
Expand All @@ -268,7 +268,7 @@ def _visit_deref(self, node: gtir.FunCall) -> MemletExpr | ValueExpr:
)
# add new termination point for the field parameter
self._add_entry_memlet_path(
it.field,
arg_expr.field,
sbs.Range.from_array(field_desc),
deref_node,
"field",
Expand All @@ -285,7 +285,7 @@ def _visit_deref(self, node: gtir.FunCall) -> MemletExpr | ValueExpr:
deref_connector,
)

elif isinstance(index_expr, ValueExpr):
elif isinstance(index_expr, DataExpr):
self._add_edge(
index_expr.node,
None,
Expand All @@ -296,14 +296,14 @@ def _visit_deref(self, node: gtir.FunCall) -> MemletExpr | ValueExpr:
else:
assert isinstance(index_expr, SymbolExpr)

dtype = it.field.desc(self.sdfg).dtype
dtype = arg_expr.field.desc(self.sdfg).dtype
return self._get_tasklet_result(dtype, deref_node, "val")

else:
assert isinstance(it, MemletExpr)
return it
# dereferencing a scalar or a literal node results in the node itself
return arg_expr

def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr:
def _visit_neighbors(self, node: gtir.FunCall) -> DataExpr:
assert len(node.args) == 2

assert isinstance(node.args[0], gtir.OffsetLiteral)
Expand Down Expand Up @@ -429,9 +429,9 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr:
)

assert isinstance(node.type, gtir_ts.ListType)
return ValueExpr(neighbors_node, node.type)
return DataExpr(neighbors_node, node.type)

def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr:
def _visit_reduce(self, node: gtir.FunCall) -> DataExpr:
op_name, reduce_init, reduce_identity = get_reduce_params(node)
dtype = reduce_identity.dtype

Expand All @@ -447,7 +447,7 @@ def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr:
# ensure that we leave the visitor in the same state as we entered
self.reduce_identity = prev_reduce_identity

assert isinstance(input_expr, MemletExpr | ValueExpr)
assert isinstance(input_expr, MemletExpr | DataExpr)
input_desc = input_expr.node.desc(self.sdfg)
assert isinstance(input_desc, dace.data.Array)

Expand Down Expand Up @@ -487,7 +487,7 @@ def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr:
dace.Memlet(data=temp_name, subset="0"),
)
assert isinstance(node.type, ts.ScalarType)
return ValueExpr(temp_node, node.type)
return DataExpr(temp_node, node.type)

def _split_shift_args(
self, args: list[gtir.Expr]
Expand Down Expand Up @@ -518,11 +518,11 @@ def _visit_shift_multidim(
return offset_provider_arg, offset_value_arg, it

def _make_cartesian_shift(
self, it: IteratorExpr, offset_dim: gtx_common.Dimension, offset_expr: IteratorIndexExpr
self, it: IteratorExpr, offset_dim: gtx_common.Dimension, offset_expr: ValueExpr
) -> IteratorExpr:
"""Implements cartesian shift along one dimension."""
assert offset_dim in it.dimensions
new_index: SymbolExpr | ValueExpr
new_index: SymbolExpr | DataExpr
assert offset_dim in it.indices
index_expr = it.indices[offset_dim]
if isinstance(index_expr, SymbolExpr) and isinstance(offset_expr, SymbolExpr):
Expand Down Expand Up @@ -563,7 +563,7 @@ def _make_cartesian_shift(
dynamic_offset_tasklet,
input_connector,
)
elif isinstance(input_expr, ValueExpr):
elif isinstance(input_expr, DataExpr):
self._add_edge(
input_expr.node,
None,
Expand All @@ -588,15 +588,15 @@ def _make_cartesian_shift(

def _make_dynamic_neighbor_offset(
self,
offset_expr: MemletExpr | ValueExpr,
offset_expr: MemletExpr | DataExpr,
offset_table_node: dace.nodes.AccessNode,
origin_index: SymbolExpr,
) -> ValueExpr:
) -> DataExpr:
"""
Implements access to neighbor connectivity table by means of a tasklet node.
It requires a dynamic offset value, either obtained from a field/scalar argument (`MemletExpr`)
or computed by another tasklet (`ValueExpr`).
or computed by another tasklet (`DataExpr`).
"""
new_index_connector = "neighbor_index"
tasklet_node = self._add_tasklet(
Expand Down Expand Up @@ -635,7 +635,7 @@ def _make_unstructured_shift(
it: IteratorExpr,
connectivity: gtx_common.Connectivity,
offset_table_node: dace.nodes.AccessNode,
offset_expr: IteratorIndexExpr,
offset_expr: ValueExpr,
) -> IteratorExpr:
"""Implements shift in unstructured domain by means of a neighbor table."""
assert connectivity.neighbor_axis in it.dimensions
Expand Down Expand Up @@ -702,7 +702,7 @@ def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr:
it, offset_provider, offset_table_node, offset_expr
)

def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | MemletExpr | ValueExpr:
def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | ValueExpr:
if cpm.is_call_to(node, "deref"):
return self._visit_deref(node)

Expand All @@ -719,10 +719,10 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | MemletExpr | Value
assert isinstance(node.fun, gtir.SymRef)

node_internals = []
node_connections: dict[str, MemletExpr | ValueExpr] = {}
node_connections: dict[str, MemletExpr | DataExpr] = {}
for i, arg in enumerate(node.args):
arg_expr = self.visit(arg)
if isinstance(arg_expr, MemletExpr | ValueExpr):
if isinstance(arg_expr, MemletExpr | DataExpr):
# the argument value is the result of a tasklet node or direct field access
connector = f"__inp_{i}"
node_connections[connector] = arg_expr
Expand All @@ -745,7 +745,7 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | MemletExpr | Value
)

for connector, arg_expr in node_connections.items():
if isinstance(arg_expr, ValueExpr):
if isinstance(arg_expr, DataExpr):
self._add_edge(
arg_expr.node,
None,
Expand All @@ -768,11 +768,11 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | MemletExpr | Value

def visit_Lambda(
self, node: gtir.Lambda, args: list[IteratorExpr | MemletExpr | SymbolExpr]
) -> tuple[list[InputConnection], ValueExpr]:
) -> tuple[list[InputConnection], DataExpr]:
for p, arg in zip(node.params, args, strict=True):
self.symbol_map[str(p.id)] = arg
output_expr: MemletExpr | SymbolExpr | ValueExpr = self.visit(node.expr)
if isinstance(output_expr, ValueExpr):
output_expr: MemletExpr | SymbolExpr | DataExpr = self.visit(node.expr)
if isinstance(output_expr, DataExpr):
return self.input_connections, output_expr

if isinstance(output_expr, MemletExpr):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import dace

from gt4py import eve
from gt4py.next import common as gtx_common
from gt4py.next.iterator import ir as gtir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
Expand Down Expand Up @@ -106,3 +107,41 @@ def get_tuple_type(data: tuple[Any, ...]) -> ts.TupleType:
return ts.TupleType(
types=[get_tuple_type(d) if isinstance(d, tuple) else d.data_type for d in data]
)


def patch_gtir(ir: gtir.Program) -> gtir.Program:
"""
Make the IR compliant with the requirements of lowering to SDFG.
Applies canonicalization of as_fieldop expressions as well as some temporary workarounds.
This allows to lower the IR to SDFG for some special cases.
"""

class PatchGTIR(eve.PreserveLocationVisitor, eve.NodeTranslator):
def visit_FunCall(self, node: gtir.FunCall) -> gtir.Node:
if cpm.is_applied_as_fieldop(node):
assert isinstance(node.fun, gtir.FunCall)
assert isinstance(node.type, ts.FieldType)

# Handle the case of fieldop without domain. This case should never happen, but domain
# inference currently produces this kind of nodes for unreferenced tuple fields.
# TODO(tehrengruber): remove this workaround once domain ineference supports this case
if len(node.fun.args) == 1:
return gtir.Literal(value="0", type=node.type.dtype)

assert len(node.fun.args) == 2
stencil = node.fun.args[0]

# Canonicalize as_fieldop: always expect a lambda expression.
# Here we replace the call to deref with a lambda expression and empty arguments list.
if cpm.is_ref_to(stencil, "deref"):
node.fun.args[0] = gtir.Lambda(
expr=gtir.FunCall(fun=stencil, args=node.args), params=[]
)
node.args = []

node.args = [self.visit(arg) for arg in node.args]
node.fun = self.visit(node.fun)
return node

return PatchGTIR().visit(ir)
Loading

0 comments on commit 9bb6c94

Please sign in to comment.