Skip to content

Commit

Permalink
Merge branch 'main' into romanc/warn-if-dace-not-found
Browse files Browse the repository at this point in the history
  • Loading branch information
egparedes authored Oct 18, 2024
2 parents db2759b + 0a27c7a commit a2f5c6f
Show file tree
Hide file tree
Showing 19 changed files with 603 additions and 245 deletions.
5 changes: 4 additions & 1 deletion src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,10 @@ def create_if(true_: itir.Expr, false_: itir.Expr) -> itir.FunCall:
_visit_concat_where = _visit_where # TODO(havogt): upgrade concat_where

def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
return self.visit(node.args[0], **kwargs)
expr = self.visit(node.args[0], **kwargs)
if isinstance(node.args[0].type, ts.ScalarType):
return im.as_fieldop(im.ref("deref"))(expr)
return expr

def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
return self._map(self.visit(node.func, **kwargs), *node.args)
Expand Down
12 changes: 8 additions & 4 deletions src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,14 @@ def __str__(self) -> str:
return pformat(self)

def __hash__(self) -> int:
return hash(type(self)) ^ hash(
tuple(
hash(tuple(v)) if isinstance(v, list) else hash(v)
for v in self.iter_children_values()
return hash(
(
type(self),
*(
tuple(v) if isinstance(v, list) else v
for (k, v) in self.iter_children_items()
if k not in ["location", "type"]
),
)
)

Expand Down
10 changes: 10 additions & 0 deletions src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]:
)


def is_applied_map(arg: itir.Node) -> TypeGuard[itir.FunCall]:
"""Match expressions of the form `map(λ(...) → ...)(...)`."""
return (
isinstance(arg, itir.FunCall)
and isinstance(arg.fun, itir.FunCall)
and isinstance(arg.fun.fun, itir.SymRef)
and arg.fun.fun.id == "map_"
)


def is_applied_reduce(arg: itir.Node) -> TypeGuard[itir.FunCall]:
"""Match expressions of the form `reduce(λ(...) → ...)(...)`."""
return (
Expand Down
19 changes: 5 additions & 14 deletions src/gt4py/next/iterator/transforms/fuse_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
# SPDX-License-Identifier: BSD-3-Clause

import dataclasses
from typing import TypeGuard

from gt4py.eve import NodeTranslator, traits
from gt4py.eve.utils import UIDGenerator
Expand All @@ -16,14 +15,6 @@
from gt4py.next.iterator.transforms import inline_lambdas


def _is_map(node: ir.Node) -> TypeGuard[ir.FunCall]:
return (
isinstance(node, ir.FunCall)
and isinstance(node.fun, ir.FunCall)
and node.fun.fun == ir.SymRef(id="map_")
)


@dataclasses.dataclass(frozen=True)
class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator):
"""
Expand Down Expand Up @@ -58,10 +49,10 @@ def _as_lambda(self, fun: ir.SymRef | ir.Lambda, param_count: int) -> ir.Lambda:

def visit_FunCall(self, node: ir.FunCall, **kwargs):
node = self.generic_visit(node)
if _is_map(node) or cpm.is_applied_reduce(node):
if any(_is_map(arg) for arg in node.args):
if cpm.is_applied_map(node) or cpm.is_applied_reduce(node):
if any(cpm.is_applied_map(arg) for arg in node.args):
first_param = (
0 if _is_map(node) else 1
0 if cpm.is_applied_map(node) else 1
) # index of the first param of op that maps to args (0 for map, 1 for reduce)
assert isinstance(node.fun, ir.FunCall)
assert isinstance(node.fun.args[0], (ir.Lambda, ir.SymRef))
Expand All @@ -76,7 +67,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs):
new_params.append(outer_op.params[0])

for i in range(len(node.args)):
if _is_map(node.args[i]):
if cpm.is_applied_map(node.args[i]):
map_call = node.args[i]
assert isinstance(map_call, ir.FunCall)
assert isinstance(map_call.fun, ir.FunCall)
Expand All @@ -102,7 +93,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs):
new_body
) # removes one level of nesting (the recursive inliner could simplify more, however this can also be done on the full tree later)
new_op = ir.Lambda(params=new_params, expr=new_body)
if _is_map(node):
if cpm.is_applied_map(node):
return ir.FunCall(
fun=ir.FunCall(fun=ir.SymRef(id="map_"), args=[new_op]), args=new_args
)
Expand Down
45 changes: 45 additions & 0 deletions src/gt4py/next/iterator/transforms/prune_casts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.next.iterator import ir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.type_system import type_specifications as ts


class PruneCasts(PreserveLocationVisitor, NodeTranslator):
"""
Removes cast expressions where the argument is already in the target type.
This transformation requires the IR to be fully type-annotated,
therefore it should be applied after type-inference.
"""

def visit_FunCall(self, node: ir.FunCall) -> ir.Node:
node = self.generic_visit(node)

if not cpm.is_call_to(node, "cast_"):
return node

value, type_constructor = node.args

assert (
value.type
and isinstance(type_constructor, ir.SymRef)
and (type_constructor.id in ir.TYPEBUILTINS)
)
dtype = ts.ScalarType(kind=getattr(ts.ScalarKind, type_constructor.id.upper()))

if value.type == dtype:
return value

return node

@classmethod
def apply(cls, node: ir.Node) -> ir.Node:
return cls().visit(node)
4 changes: 3 additions & 1 deletion src/gt4py/next/iterator/type_system/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,9 +504,11 @@ def visit_Program(self, node: itir.Program, *, ctx) -> it_ts.ProgramType:
def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.TupleType:
domain = self.visit(node.domain, ctx=ctx)
assert isinstance(domain, it_ts.DomainType)
assert domain.dims != "unknown"
assert node.dtype
return type_info.apply_to_primitive_constituents(
lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), node.dtype
lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), # type: ignore[arg-type] # ensured by domain.dims != "unknown" above
node.dtype,
)

def visit_IfStmt(self, node: itir.IfStmt, *, ctx) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/type_system/type_specifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class NamedRangeType(ts.TypeSpec):

@dataclasses.dataclass(frozen=True)
class DomainType(ts.DataType):
dims: list[common.Dimension]
dims: list[common.Dimension] | Literal["unknown"]


@dataclasses.dataclass(frozen=True)
Expand Down
27 changes: 23 additions & 4 deletions src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,17 +271,36 @@ def _convert_as_fieldop_input_to_iterator(

@_register_builtin_type_synthesizer
def as_fieldop(
stencil: TypeSynthesizer, domain: it_ts.DomainType, offset_provider: common.OffsetProvider
stencil: TypeSynthesizer,
domain: Optional[it_ts.DomainType] = None,
*,
offset_provider: common.OffsetProvider,
) -> TypeSynthesizer:
# In case we don't have a domain argument to `as_fieldop` we can not infer the exact result
# type. In order to still allow some passes which don't need this information to run before the
# domain inference, we continue with a dummy domain. One example is the CollapseTuple pass
# which only needs information about the structure, e.g. how many tuple elements does this node
# have, but not the dimensions of a field.
# Note that it might appear as if using the TraceShift pass would allow us to deduce the return
# type of `as_fieldop` without a domain, but this is not the case, since we don't have
# information on the ordering of dimensions. In this example
# `as_fieldop(it1, it2 -> deref(it1) + deref(it2))(i_field, j_field)`
# it is unclear if the result has dimension I, J or J, I.
if domain is None:
domain = it_ts.DomainType(dims="unknown")

@TypeSynthesizer
def applied_as_fieldop(*fields) -> ts.FieldType:
def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType:
stencil_return = stencil(
*(_convert_as_fieldop_input_to_iterator(domain, field) for field in fields),
offset_provider=offset_provider,
)
assert isinstance(stencil_return, ts.DataType)
return type_info.apply_to_primitive_constituents(
lambda el_type: ts.FieldType(dims=domain.dims, dtype=el_type), stencil_return
lambda el_type: ts.FieldType(dims=domain.dims, dtype=el_type)
if domain.dims != "unknown"
else ts.DeferredType(constraint=ts.FieldType),
stencil_return,
)

return applied_as_fieldop
Expand Down Expand Up @@ -329,7 +348,7 @@ def applied_reduce(*args: it_ts.ListType, offset_provider: common.OffsetProvider


@_register_builtin_type_synthesizer
def shift(*offset_literals, offset_provider) -> TypeSynthesizer:
def shift(*offset_literals, offset_provider: common.OffsetProvider) -> TypeSynthesizer:
@TypeSynthesizer
def apply_shift(
it: it_ts.IteratorType | ts.DeferredType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,13 @@ def as_dace_type(type_: ts.ScalarType) -> dace.typeclass:
raise ValueError(f"Scalar type '{type_}' not supported.")


def as_scalar_type(typestr: str) -> ts.ScalarType:
"""Obtain GT4Py scalar type from generic numpy string representation."""
def as_itir_type(dtype: dace.typeclass) -> ts.ScalarType:
"""Get GT4Py scalar representation of a DaCe type."""
type_name = str(dtype.as_numpy_dtype())
try:
kind = getattr(ts.ScalarKind, typestr.upper())
kind = getattr(ts.ScalarKind, type_name.upper())
except AttributeError as ex:
raise ValueError(f"Data type {typestr} not supported.") from ex
raise ValueError(f"Data type {type_name} not supported.") from ex
return ts.ScalarType(kind)


Expand Down
Loading

0 comments on commit a2f5c6f

Please sign in to comment.