diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py
index 8df4723502..a4b074a4b6 100644
--- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py
+++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py
@@ -14,6 +14,7 @@
from typing import TypeGuard
from gt4py.next.iterator import ir as itir
+from gt4py.next.iterator.ir_utils import ir_makers as im
def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]:
@@ -24,3 +25,13 @@ def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]:
and isinstance(arg.fun.fun, itir.SymRef)
and arg.fun.fun.id == "lift"
)
+
+
+def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]:
+ """Match expression of the form `(λ(...) → ...)(...)`."""
+ return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda)
+
+
+def is_if_call(node: itir.Expr) -> TypeGuard[itir.FunCall]:
+ """Match expression of the form `if_(cond, true_branch, false_branch)`."""
+ return isinstance(node, itir.FunCall) and node.fun == im.ref("if_")
diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py
index 94a2646422..4337e8512a 100644
--- a/src/gt4py/next/iterator/ir_utils/ir_makers.py
+++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py
@@ -12,7 +12,8 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
-from typing import Callable, Union
+import typing
+from typing import Callable, Iterable, Union
from gt4py._core import definitions as core_defs
from gt4py.next.iterator import ir as itir
@@ -242,16 +243,31 @@ class let:
--------
>>> str(let("a", "b")("a")) # doctest: +ELLIPSIS
'(λ(a) → a)(b)'
- >>> str(let("a", 1,
- ... "b", 2
+ >>> str(let(("a", 1),
+ ... ("b", 2)
... )(plus("a", "b")))
'(λ(a, b) → a + b)(1, 2)'
"""
- def __init__(self, *vars_and_values):
- assert len(vars_and_values) % 2 == 0
- self.vars = vars_and_values[0::2]
- self.init_forms = vars_and_values[1::2]
+ @typing.overload
+ def __init__(self, var: str | itir.Sym, init_form: itir.Expr): ...
+
+ @typing.overload
+ def __init__(self, *args: Iterable[tuple[str | itir.Sym, itir.Expr]]): ...
+
+ def __init__(self, *args):
+ if all(isinstance(arg, tuple) and len(arg) == 2 for arg in args):
+ assert isinstance(args, tuple)
+ assert all(isinstance(arg, tuple) and len(arg) == 2 for arg in args)
+ self.vars = [var for var, _ in args]
+ self.init_forms = [init_form for _, init_form in args]
+ elif len(args) == 2:
+ self.vars = [args[0]]
+ self.init_forms = [args[1]]
+ else:
+ raise TypeError(
+ "Invalid arguments: expected a variable name and an init form or a list thereof."
+ )
def __call__(self, form):
return call(lambda_(*self.vars)(form))(*self.init_forms)
diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py
new file mode 100644
index 0000000000..4336649d06
--- /dev/null
+++ b/src/gt4py/next/iterator/ir_utils/misc.py
@@ -0,0 +1,79 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import dataclasses
+from collections import ChainMap
+
+from gt4py import eve
+from gt4py.eve import utils as eve_utils
+from gt4py.next.iterator import ir as itir
+from gt4py.next.iterator.ir_utils import ir_makers as im
+
+
+@dataclasses.dataclass(frozen=True)
+class CannonicalizeBoundSymbolNames(eve.NodeTranslator):
+ """
+ Given an iterator expression cannonicalize all bound symbol names.
+
+ If two such expression are in the same scope and equal so are their values.
+
+ >>> testee1 = im.lambda_("a")(im.plus("a", "b"))
+ >>> cannonicalized_testee1 = CannonicalizeBoundSymbolNames.apply(testee1)
+ >>> str(cannonicalized_testee1)
+ 'λ(_csym_1) → _csym_1 + b'
+
+ >>> testee2 = im.lambda_("c")(im.plus("c", "b"))
+ >>> cannonicalized_testee2 = CannonicalizeBoundSymbolNames.apply(testee2)
+ >>> assert cannonicalized_testee1 == cannonicalized_testee2
+ """
+
+ _uids: eve_utils.UIDGenerator = dataclasses.field(
+ init=False, repr=False, default_factory=lambda: eve_utils.UIDGenerator(prefix="_csym")
+ )
+
+ @classmethod
+ def apply(cls, node: itir.Expr):
+ return cls().visit(node, sym_map=ChainMap({}))
+
+ def visit_Lambda(self, node: itir.Lambda, *, sym_map: ChainMap):
+ sym_map = sym_map.new_child()
+ for param in node.params:
+ sym_map[str(param.id)] = self._uids.sequential_id()
+
+ return im.lambda_(*sym_map.values())(self.visit(node.expr, sym_map=sym_map))
+
+ def visit_SymRef(self, node: itir.SymRef, *, sym_map: dict[str, str]):
+ return im.ref(sym_map[node.id]) if node.id in sym_map else node
+
+
+def is_equal(a: itir.Expr, b: itir.Expr):
+ """
+ Return true if two expressions have provably equal values.
+
+ Be aware that this function might return false even though the two expression have the same
+ value.
+
+ >>> testee1 = im.lambda_("a")(im.plus("a", "b"))
+ >>> testee2 = im.lambda_("c")(im.plus("c", "b"))
+ >>> assert is_equal(testee1, testee2)
+
+ >>> testee1 = im.lambda_("a")(im.plus("a", "b"))
+ >>> testee2 = im.lambda_("c")(im.plus("c", "d"))
+ >>> assert not is_equal(testee1, testee2)
+ """
+ # TODO(tehrengruber): Extend this function cover more cases than just those with equal
+ # structure, e.g., by also canonicalization of the structure.
+ return a == b or (
+ CannonicalizeBoundSymbolNames.apply(a) == CannonicalizeBoundSymbolNames.apply(b)
+ )
diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py
index 42bbf28909..51daffed05 100644
--- a/src/gt4py/next/iterator/transforms/collapse_tuple.py
+++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py
@@ -11,22 +11,29 @@
# distribution for a copy of the license or check .
#
# SPDX-License-Identifier: GPL-3.0-or-later
-from dataclasses import dataclass
+import dataclasses
+import enum
+import functools
+import operator
from typing import Optional
from gt4py import eve
+from gt4py.eve import utils as eve_utils
from gt4py.next import type_inference
from gt4py.next.iterator import ir, type_inference as it_type_inference
+from gt4py.next.iterator.ir_utils import ir_makers as im, misc as ir_misc
+from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_if_call, is_let
+from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas, inline_lambda
class UnknownLength:
pass
-def _get_tuple_size(elem: ir.Node, node_types: Optional[dict] = None) -> int | type[UnknownLength]:
- if node_types:
- type_ = node_types[id(elem)]
- # global inference should always give a length, function should fail otherwise
+def _get_tuple_size(elem: ir.Node, use_global_information: bool) -> int | type[UnknownLength]:
+ if use_global_information:
+ type_ = elem.annex.type
+ # global inference should always give a length, fail otherwise
assert isinstance(type_, it_type_inference.Val) and isinstance(
type_.dtype, it_type_inference.Tuple
)
@@ -47,7 +54,31 @@ def _get_tuple_size(elem: ir.Node, node_types: Optional[dict] = None) -> int | t
return len(type_.dtype)
-@dataclass(frozen=True)
+def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr):
+ """Given a itir.FunCall return a new call with one of its argument replaced."""
+ return ir.FunCall(
+ fun=node.fun, args=[arg if i != arg_idx else new_arg for i, arg in enumerate(node.args)]
+ )
+
+
+def _is_trivial_make_tuple_call(node: ir.Expr):
+ """Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof."""
+ if not (isinstance(node, ir.FunCall) and node.fun == im.ref("make_tuple")):
+ return False
+ if not all(
+ isinstance(arg, (ir.SymRef, ir.Literal)) or _is_trivial_make_tuple_call(arg)
+ for arg in node.args
+ ):
+ return False
+ return True
+
+
+# TODO(tehrengruber): Conceptually the structure of this pass makes sense: Visit depth first,
+# transform each node until no transformations apply anymore, whenever a node is to be transformed
+# go through all available transformation and apply them. However the final result here still
+# reads a little convoluted and is also different to how we write other transformations. We
+# should revisit the pattern here and try to find a more general mechanism.
+@dataclasses.dataclass(frozen=True)
class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator):
"""
Simplifies `make_tuple`, `tuple_get` calls.
@@ -56,10 +87,44 @@ class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator):
- `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i`
"""
+ # TODO(tehrengruber): This Flag mechanism is a little low level. What we actually want
+ # is something like a pass manager, where for each pattern we have a corresponding
+ # transformation, etc.
+ class Flag(enum.Flag):
+ #: `make_tuple(tuple_get(0, t), tuple_get(1, t), ..., tuple_get(N-1,t))` -> `t`
+ COLLAPSE_MAKE_TUPLE_TUPLE_GET = enum.auto()
+ #: `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i`
+ COLLAPSE_TUPLE_GET_MAKE_TUPLE = enum.auto()
+ #: `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))`
+ PROPAGATE_TUPLE_GET = enum.auto()
+ #: `{1, 2}` -> `(λ(_tuple_el_1, _tuple_el_2) → {_tuple_el_1, _tuple_el_2})(1, 2)`
+ LETIFY_MAKE_TUPLE_ELEMENTS = enum.auto()
+ #: `let(tup, {trivial_expr1, trivial_expr2})(foo(tup))`
+ #: -> `foo({trivial_expr1, trivial_expr2})`
+ INLINE_TRIVIAL_MAKE_TUPLE = enum.auto()
+ #: `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]`
+ PROPAGATE_TO_IF_ON_TUPLES = enum.auto()
+ #: `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))`
+ PROPAGATE_NESTED_LET = enum.auto()
+ #: `let(a, 1)(a)` -> `1`
+ INLINE_TRIVIAL_LET = enum.auto()
+
+ @classmethod
+ def all(self): # noqa: A003 # shadowing a python builtin
+ return functools.reduce(operator.or_, self.__members__.values())
+
ignore_tuple_size: bool
- collapse_make_tuple_tuple_get: bool
- collapse_tuple_get_make_tuple: bool
use_global_type_inference: bool
+ flags: Flag = Flag.all()
+
+ PRESERVED_ANNEX_ATTRS = ("type",)
+
+ # we use one UID generator per instance such that the generated ids are
+ # stable across multiple runs (required for caching to properly work)
+ _letify_make_tuple_uids: eve_utils.UIDGenerator = dataclasses.field(
+ init=False, repr=False, default_factory=lambda: eve_utils.UIDGenerator(prefix="_tuple_el")
+ )
+
_node_types: Optional[dict[int, type_inference.Type]] = None
@classmethod
@@ -68,34 +133,77 @@ def apply(
node: ir.Node,
*,
ignore_tuple_size: bool = False,
- # the following options are mostly for allowing separate testing of the modes
- collapse_make_tuple_tuple_get: bool = True,
- collapse_tuple_get_make_tuple: bool = True,
use_global_type_inference: bool = False,
+ remove_letified_make_tuple_elements: bool = True,
+ # manually passing flags is mostly for allowing separate testing of the modes
+ flags=None,
) -> ir.Node:
"""
Simplifies `make_tuple`, `tuple_get` calls.
- If `ignore_tuple_size`, apply the transformation even if length of the inner tuple
- is greater than the length of the outer tuple.
+ Arguments:
+ node: The node to transform.
+
+ Keyword arguments:
+ ignore_tuple_size: Apply the transformation even if length of the inner tuple is greater
+ than the length of the outer tuple.
+ use_global_type_inference: Run global type inference to determine tuple sizes.
+ remove_letified_make_tuple_elements: Run `InlineLambdas` as a post-processing step
+ to remove left-overs from `LETIFY_MAKE_TUPLE_ELEMENTS` transformation.
+ `(λ(_tuple_el_1, _tuple_el_2) → {_tuple_el_1, _tuple_el_2})(1, 2)` -> {1, 2}`
"""
- node_types = it_type_inference.infer_all(node) if use_global_type_inference else None
- return cls(
- ignore_tuple_size,
- collapse_make_tuple_tuple_get,
- collapse_tuple_get_make_tuple,
- use_global_type_inference,
- node_types,
+ flags = flags or cls.flags
+ if use_global_type_inference:
+ it_type_inference.infer_all(node, save_to_annex=True)
+
+ new_node = cls(
+ ignore_tuple_size=ignore_tuple_size,
+ use_global_type_inference=use_global_type_inference,
+ flags=flags,
).visit(node)
- def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node:
- if (
- self.collapse_make_tuple_tuple_get
- and node.fun == ir.SymRef(id="make_tuple")
- and all(
- isinstance(arg, ir.FunCall) and arg.fun == ir.SymRef(id="tuple_get")
- for arg in node.args
+ # inline to remove left-overs from LETIFY_MAKE_TUPLE_ELEMENTS. this is important
+ # as otherwise two equal expressions containing a tuple will not be equal anymore
+ # and the CSE pass can not remove them.
+ # TODO(tehrengruber): test case for `scan(lambda carry: {1, 2})`
+ # (see solve_nonhydro_stencil_52_like_z_q_tup)
+ if remove_letified_make_tuple_elements:
+ new_node = InlineLambdas.apply(
+ new_node, opcount_preserving=True, force_inline_lambda_args=False
)
+
+ return new_node
+
+ def visit_FunCall(self, node: ir.FunCall) -> ir.Node:
+ node = self.generic_visit(node)
+ return self.fp_transform(node)
+
+ def fp_transform(self, node: ir.Node) -> ir.Node:
+ while True:
+ new_node = self.transform(node)
+ if new_node is None:
+ break
+ assert new_node != node
+ node = new_node
+ return node
+
+ def transform(self, node: ir.Node) -> Optional[ir.Node]:
+ if not isinstance(node, ir.FunCall):
+ return None
+
+ for transformation in self.Flag:
+ if self.flags & transformation:
+ assert isinstance(transformation.name, str)
+ method = getattr(self, f"transform_{transformation.name.lower()}")
+ result = method(node)
+ if result is not None:
+ return result
+ return None
+
+ def transform_collapse_make_tuple_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]:
+ if node.fun == ir.SymRef(id="make_tuple") and all(
+ isinstance(arg, ir.FunCall) and arg.fun == ir.SymRef(id="tuple_get")
+ for arg in node.args
):
# `make_tuple(tuple_get(0, t), tuple_get(1, t), ..., tuple_get(N-1,t))` -> `t`
assert isinstance(node.args[0], ir.FunCall)
@@ -104,17 +212,19 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node:
for i, v in enumerate(node.args):
assert isinstance(v, ir.FunCall)
assert isinstance(v.args[0], ir.Literal)
- if not (int(v.args[0].value) == i and v.args[1] == first_expr):
+ if not (int(v.args[0].value) == i and ir_misc.is_equal(v.args[1], first_expr)):
# tuple argument differs, just continue with the rest of the tree
- return self.generic_visit(node)
+ return None
- if self.ignore_tuple_size or _get_tuple_size(first_expr, self._node_types) == len(
- node.args
- ):
+ if self.ignore_tuple_size or _get_tuple_size(
+ first_expr, self.use_global_type_inference
+ ) == len(node.args):
return first_expr
+ return None
+
+ def transform_collapse_tuple_get_make_tuple(self, node: ir.FunCall) -> Optional[ir.Node]:
if (
- self.collapse_tuple_get_make_tuple
- and node.fun == ir.SymRef(id="tuple_get")
+ node.fun == ir.SymRef(id="tuple_get")
and isinstance(node.args[1], ir.FunCall)
and node.args[1].fun == ir.SymRef(id="make_tuple")
and isinstance(node.args[0], ir.Literal)
@@ -127,4 +237,106 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node:
make_tuple_call.args
), f"Index {idx} is out of bounds for tuple of size {len(make_tuple_call.args)}"
return node.args[1].args[idx]
- return self.generic_visit(node)
+ return None
+
+ def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]:
+ if node.fun == ir.SymRef(id="tuple_get") and isinstance(node.args[0], ir.Literal):
+ # TODO(tehrengruber): extend to general symbols as long as the tail call in the let
+ # does not capture
+ # `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))`
+ if is_let(node.args[1]):
+ idx, let_expr = node.args
+ return im.call(
+ im.lambda_(*let_expr.fun.params)( # type: ignore[attr-defined] # ensured by is_let
+ self.fp_transform(im.tuple_get(idx.value, let_expr.fun.expr)) # type: ignore[attr-defined] # ensured by is_let
+ )
+ )(
+ *let_expr.args # type: ignore[attr-defined] # ensured by is_let
+ )
+ elif isinstance(node.args[1], ir.FunCall) and node.args[1].fun == im.ref("if_"):
+ idx = node.args[0]
+ cond, true_branch, false_branch = node.args[1].args
+ return im.if_(
+ cond,
+ self.fp_transform(im.tuple_get(idx.value, true_branch)),
+ self.fp_transform(im.tuple_get(idx.value, false_branch)),
+ )
+ return None
+
+ def transform_letify_make_tuple_elements(self, node: ir.FunCall) -> Optional[ir.Node]:
+ if node.fun == ir.SymRef(id="make_tuple"):
+ # `make_tuple(expr1, expr1)`
+ # -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))`
+ bound_vars: dict[str, ir.Expr] = {}
+ new_args: list[ir.Expr] = []
+ for arg in node.args:
+ if (
+ isinstance(node, ir.FunCall)
+ and node.fun == im.ref("make_tuple")
+ and not _is_trivial_make_tuple_call(node)
+ ):
+ el_name = self._letify_make_tuple_uids.sequential_id()
+ new_args.append(im.ref(el_name))
+ bound_vars[el_name] = arg
+ else:
+ new_args.append(arg)
+
+ if bound_vars:
+ return self.fp_transform(im.let(*bound_vars.items())(im.call(node.fun)(*new_args))) # type: ignore[arg-type] # mypy not smart enough
+ return None
+
+ def transform_inline_trivial_make_tuple(self, node: ir.FunCall) -> Optional[ir.Node]:
+ if is_let(node):
+ # `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))`
+ # -> `foo(make_tuple(trivial_expr1, trivial_expr2))`
+ eligible_params = [_is_trivial_make_tuple_call(arg) for arg in node.args]
+ if any(eligible_params):
+ return self.visit(inline_lambda(node, eligible_params=eligible_params))
+ return None
+
+ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.Node]:
+ if not node.fun == im.ref("if_"):
+ # TODO(tehrengruber): This significantly increases the size of the tree. Revisit.
+ # TODO(tehrengruber): Only inline if type of branch value is a tuple.
+ # Examples:
+ # `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]`
+ # `let (b, if cond then {1, 2} else {3, 4})) b[0]`
+ # -> `if cond then let(b, {1, 2})(b[0]) else let(b, {3, 4})(b[0])`
+ for i, arg in enumerate(node.args):
+ if is_if_call(arg):
+ cond, true_branch, false_branch = arg.args
+ new_true_branch = self.fp_transform(_with_altered_arg(node, i, true_branch))
+ new_false_branch = self.fp_transform(_with_altered_arg(node, i, false_branch))
+ return im.if_(cond, new_true_branch, new_false_branch)
+ return None
+
+ def transform_propagate_nested_let(self, node: ir.FunCall) -> Optional[ir.Node]:
+ if is_let(node):
+ # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))`
+ outer_vars = {}
+ inner_vars = {}
+ original_inner_expr = node.fun.expr # type: ignore[attr-defined] # ensured by is_let
+ for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let
+ assert arg_sym not in inner_vars # TODO(tehrengruber): fix collisions
+ if is_let(arg):
+ for sym, val in zip(arg.fun.params, arg.args): # type: ignore[attr-defined] # ensured by is_let
+ assert sym not in outer_vars # TODO(tehrengruber): fix collisions
+ outer_vars[sym] = val
+ inner_vars[arg_sym] = arg.fun.expr # type: ignore[attr-defined] # ensured by is_let
+ else:
+ inner_vars[arg_sym] = arg
+ if outer_vars:
+ return self.fp_transform(
+ im.let(*outer_vars.items())( # type: ignore[arg-type] # mypy not smart enough
+ self.fp_transform(im.let(*inner_vars.items())(original_inner_expr))
+ )
+ )
+ return None
+
+ def transform_inline_trivial_let(self, node: ir.FunCall) -> Optional[ir.Node]:
+ if is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let
+ # `let(a, 1)(a)` -> `1`
+ for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let
+ if node.fun.expr == im.ref(arg_sym.id): # type: ignore[attr-defined] # ensured by is_let
+ return arg
+ return None
diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py
index fe14a8f580..b9dcc094c4 100644
--- a/src/gt4py/next/iterator/transforms/pass_manager.py
+++ b/src/gt4py/next/iterator/transforms/pass_manager.py
@@ -117,6 +117,10 @@ def apply_common_transforms(
# to limit number of times global type inference is executed, only in the last iterations.
use_global_type_inference=inlined == ir,
)
+ # This pass is required such that a deref outside of a
+ # `tuple_get(make_tuple(let(...), ...))` call is propagated into the let after the
+ # `tuple_get` is removed by the `CollapseTuple` pass.
+ inlined = PropagateDeref.apply(inlined)
if inlined == ir:
break
diff --git a/src/gt4py/next/iterator/transforms/propagate_deref.py b/src/gt4py/next/iterator/transforms/propagate_deref.py
index 783e54ede0..9f8bff7a84 100644
--- a/src/gt4py/next/iterator/transforms/propagate_deref.py
+++ b/src/gt4py/next/iterator/transforms/propagate_deref.py
@@ -15,6 +15,7 @@
from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.eve.pattern_matching import ObjectPattern as P
from gt4py.next.iterator import ir
+from gt4py.next.iterator.ir_utils import ir_makers as im
# TODO(tehrengruber): This pass can be generalized to all builtins, e.g.
@@ -56,4 +57,12 @@ def visit_FunCall(self, node: ir.FunCall):
),
args=lambda_args,
)
+ elif (
+ node.fun == im.ref("deref")
+ and isinstance(node.args[0], ir.FunCall)
+ and node.args[0].fun == im.ref("if_")
+ ):
+ cond, true_branch, false_branch = node.args[0].args
+ return im.if_(cond, im.deref(true_branch), im.deref(false_branch))
+
return self.generic_visit(node)
diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py
index 1444b0a64f..330f66bee5 100644
--- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py
+++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py
@@ -20,7 +20,11 @@ def test_simple_make_tuple_tuple_get():
tuple_of_size_2 = im.make_tuple("first", "second")
testee = im.make_tuple(im.tuple_get(0, tuple_of_size_2), im.tuple_get(1, tuple_of_size_2))
- actual = CollapseTuple.apply(testee, collapse_tuple_get_make_tuple=False)
+ actual = CollapseTuple.apply(
+ testee,
+ remove_letified_make_tuple_elements=False,
+ flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET,
+ )
expected = tuple_of_size_2
assert actual == expected
@@ -32,7 +36,11 @@ def test_nested_make_tuple_tuple_get():
im.tuple_get(0, tup_of_size2_from_lambda), im.tuple_get(1, tup_of_size2_from_lambda)
)
- actual = CollapseTuple.apply(testee, collapse_tuple_get_make_tuple=False)
+ actual = CollapseTuple.apply(
+ testee,
+ remove_letified_make_tuple_elements=False,
+ flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET,
+ )
assert actual == tup_of_size2_from_lambda
@@ -42,7 +50,11 @@ def test_different_tuples_make_tuple_tuple_get():
t1 = im.make_tuple("foo1", "bar1")
testee = im.make_tuple(im.tuple_get(0, t0), im.tuple_get(1, t1))
- actual = CollapseTuple.apply(testee, collapse_tuple_get_make_tuple=False)
+ actual = CollapseTuple.apply(
+ testee,
+ remove_letified_make_tuple_elements=False,
+ flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET,
+ )
assert actual == testee # did nothing
@@ -50,24 +62,137 @@ def test_different_tuples_make_tuple_tuple_get():
def test_incompatible_order_make_tuple_tuple_get():
tuple_of_size_2 = im.make_tuple("first", "second")
testee = im.make_tuple(im.tuple_get(1, tuple_of_size_2), im.tuple_get(0, tuple_of_size_2))
- actual = CollapseTuple.apply(testee, collapse_tuple_get_make_tuple=False)
+ actual = CollapseTuple.apply(
+ testee,
+ remove_letified_make_tuple_elements=False,
+ flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET,
+ )
assert actual == testee # did nothing
def test_incompatible_size_make_tuple_tuple_get():
testee = im.make_tuple(im.tuple_get(0, im.make_tuple("first", "second")))
- actual = CollapseTuple.apply(testee, collapse_tuple_get_make_tuple=False)
+ actual = CollapseTuple.apply(
+ testee,
+ remove_letified_make_tuple_elements=False,
+ flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET,
+ )
assert actual == testee # did nothing
def test_merged_with_smaller_outer_size_make_tuple_tuple_get():
testee = im.make_tuple(im.tuple_get(0, im.make_tuple("first", "second")))
- actual = CollapseTuple.apply(testee, ignore_tuple_size=True)
+ actual = CollapseTuple.apply(
+ testee, ignore_tuple_size=True, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET
+ )
assert actual == im.make_tuple("first", "second")
def test_simple_tuple_get_make_tuple():
expected = im.ref("bar")
testee = im.tuple_get(1, im.make_tuple("foo", expected))
- actual = CollapseTuple.apply(testee, collapse_make_tuple_tuple_get=False)
+ actual = CollapseTuple.apply(
+ testee,
+ remove_letified_make_tuple_elements=False,
+ flags=CollapseTuple.Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE,
+ )
+ assert expected == actual
+
+
+def test_propagate_tuple_get():
+ expected = im.let(("el1", 1), ("el2", 2))(im.tuple_get(0, im.make_tuple("el1", "el2")))
+ testee = im.tuple_get(0, im.let(("el1", 1), ("el2", 2))(im.make_tuple("el1", "el2")))
+ actual = CollapseTuple.apply(
+ testee,
+ remove_letified_make_tuple_elements=False,
+ flags=CollapseTuple.Flag.PROPAGATE_TUPLE_GET,
+ )
assert expected == actual
+
+
+def test_letify_make_tuple_elements():
+ opaque_call = im.call("opaque")()
+ testee = im.make_tuple(opaque_call, opaque_call)
+ expected = im.let(("_tuple_el_1", opaque_call), ("_tuple_el_2", opaque_call))(
+ im.make_tuple("_tuple_el_1", "_tuple_el_2")
+ )
+
+ actual = CollapseTuple.apply(
+ testee,
+ remove_letified_make_tuple_elements=False,
+ flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS,
+ )
+ assert actual == expected
+
+
+def test_letify_make_tuple_with_trivial_elements():
+ testee = im.let(("a", 1), ("b", 2))(im.make_tuple("a", "b"))
+ expected = testee # did nothing
+
+ actual = CollapseTuple.apply(
+ testee,
+ remove_letified_make_tuple_elements=False,
+ flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS,
+ )
+ assert actual == expected
+
+
+def test_inline_trivial_make_tuple():
+ testee = im.let("tup", im.make_tuple("a", "b"))("tup")
+ expected = im.make_tuple("a", "b")
+
+ actual = CollapseTuple.apply(
+ testee,
+ remove_letified_make_tuple_elements=False,
+ flags=CollapseTuple.Flag.INLINE_TRIVIAL_MAKE_TUPLE,
+ )
+ assert actual == expected
+
+
+def test_propagate_to_if_on_tuples():
+ testee = im.tuple_get(0, im.if_("cond", im.make_tuple(1, 2), im.make_tuple(3, 4)))
+ expected = im.if_(
+ "cond", im.tuple_get(0, im.make_tuple(1, 2)), im.tuple_get(0, im.make_tuple(3, 4))
+ )
+ actual = CollapseTuple.apply(
+ testee,
+ remove_letified_make_tuple_elements=False,
+ flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES,
+ )
+ assert actual == expected
+
+
+def test_propagate_to_if_on_tuples_with_let():
+ testee = im.let("val", im.if_("cond", im.make_tuple(1, 2), im.make_tuple(3, 4)))(
+ im.tuple_get(0, "val")
+ )
+ expected = im.if_(
+ "cond", im.tuple_get(0, im.make_tuple(1, 2)), im.tuple_get(0, im.make_tuple(3, 4))
+ )
+ actual = CollapseTuple.apply(
+ testee,
+ remove_letified_make_tuple_elements=True,
+ flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES
+ | CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS,
+ )
+ assert actual == expected
+
+
+def test_propagate_nested_lift():
+ testee = im.let("a", im.let("b", 1)("a_val"))("a")
+ expected = im.let("b", 1)(im.let("a", "a_val")("a"))
+ actual = CollapseTuple.apply(
+ testee,
+ remove_letified_make_tuple_elements=False,
+ flags=CollapseTuple.Flag.PROPAGATE_NESTED_LET,
+ )
+ assert actual == expected
+
+
+def test_if_on_tuples_with_let():
+ testee = im.let("val", im.if_("cond", im.make_tuple(1, 2), im.make_tuple(3, 4)))(
+ im.tuple_get(0, "val")
+ )
+ expected = im.if_("cond", 1, 3)
+ actual = CollapseTuple.apply(testee, remove_letified_make_tuple_elements=False)
+ assert actual == expected
diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py
index 065095e1c2..fb7720f4d7 100644
--- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py
+++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py
@@ -213,15 +213,14 @@ def is_let(node: ir.Expr):
testee = im.plus(
im.let(
- "c",
- im.let(
- "a",
- 1,
- "b",
- 2,
- )(im.plus("a", "b")),
- "d",
- 3,
+ (
+ "c",
+ im.let(
+ ("a", 1),
+ ("b", 2),
+ )(im.plus("a", "b")),
+ ),
+ ("d", 3),
)(im.plus("c", "d")),
4,
)