diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 8df4723502..a4b074a4b6 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -14,6 +14,7 @@ from typing import TypeGuard from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]: @@ -24,3 +25,13 @@ def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]: and isinstance(arg.fun.fun, itir.SymRef) and arg.fun.fun.id == "lift" ) + + +def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]: + """Match expression of the form `(λ(...) → ...)(...)`.""" + return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda) + + +def is_if_call(node: itir.Expr) -> TypeGuard[itir.FunCall]: + """Match expression of the form `if_(cond, true_branch, false_branch)`.""" + return isinstance(node, itir.FunCall) and node.fun == im.ref("if_") diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 94a2646422..4337e8512a 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -12,7 +12,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Callable, Union +import typing +from typing import Callable, Iterable, Union from gt4py._core import definitions as core_defs from gt4py.next.iterator import ir as itir @@ -242,16 +243,31 @@ class let: -------- >>> str(let("a", "b")("a")) # doctest: +ELLIPSIS '(λ(a) → a)(b)' - >>> str(let("a", 1, - ... "b", 2 + >>> str(let(("a", 1), + ... ("b", 2) ... )(plus("a", "b"))) '(λ(a, b) → a + b)(1, 2)' """ - def __init__(self, *vars_and_values): - assert len(vars_and_values) % 2 == 0 - self.vars = vars_and_values[0::2] - self.init_forms = vars_and_values[1::2] + @typing.overload + def __init__(self, var: str | itir.Sym, init_form: itir.Expr): ... + + @typing.overload + def __init__(self, *args: Iterable[tuple[str | itir.Sym, itir.Expr]]): ... + + def __init__(self, *args): + if all(isinstance(arg, tuple) and len(arg) == 2 for arg in args): + assert isinstance(args, tuple) + assert all(isinstance(arg, tuple) and len(arg) == 2 for arg in args) + self.vars = [var for var, _ in args] + self.init_forms = [init_form for _, init_form in args] + elif len(args) == 2: + self.vars = [args[0]] + self.init_forms = [args[1]] + else: + raise TypeError( + "Invalid arguments: expected a variable name and an init form or a list thereof." + ) def __call__(self, form): return call(lambda_(*self.vars)(form))(*self.init_forms) diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py new file mode 100644 index 0000000000..4336649d06 --- /dev/null +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -0,0 +1,79 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import dataclasses +from collections import ChainMap + +from gt4py import eve +from gt4py.eve import utils as eve_utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im + + +@dataclasses.dataclass(frozen=True) +class CannonicalizeBoundSymbolNames(eve.NodeTranslator): + """ + Given an iterator expression cannonicalize all bound symbol names. + + If two such expression are in the same scope and equal so are their values. + + >>> testee1 = im.lambda_("a")(im.plus("a", "b")) + >>> cannonicalized_testee1 = CannonicalizeBoundSymbolNames.apply(testee1) + >>> str(cannonicalized_testee1) + 'λ(_csym_1) → _csym_1 + b' + + >>> testee2 = im.lambda_("c")(im.plus("c", "b")) + >>> cannonicalized_testee2 = CannonicalizeBoundSymbolNames.apply(testee2) + >>> assert cannonicalized_testee1 == cannonicalized_testee2 + """ + + _uids: eve_utils.UIDGenerator = dataclasses.field( + init=False, repr=False, default_factory=lambda: eve_utils.UIDGenerator(prefix="_csym") + ) + + @classmethod + def apply(cls, node: itir.Expr): + return cls().visit(node, sym_map=ChainMap({})) + + def visit_Lambda(self, node: itir.Lambda, *, sym_map: ChainMap): + sym_map = sym_map.new_child() + for param in node.params: + sym_map[str(param.id)] = self._uids.sequential_id() + + return im.lambda_(*sym_map.values())(self.visit(node.expr, sym_map=sym_map)) + + def visit_SymRef(self, node: itir.SymRef, *, sym_map: dict[str, str]): + return im.ref(sym_map[node.id]) if node.id in sym_map else node + + +def is_equal(a: itir.Expr, b: itir.Expr): + """ + Return true if two expressions have provably equal values. + + Be aware that this function might return false even though the two expression have the same + value. + + >>> testee1 = im.lambda_("a")(im.plus("a", "b")) + >>> testee2 = im.lambda_("c")(im.plus("c", "b")) + >>> assert is_equal(testee1, testee2) + + >>> testee1 = im.lambda_("a")(im.plus("a", "b")) + >>> testee2 = im.lambda_("c")(im.plus("c", "d")) + >>> assert not is_equal(testee1, testee2) + """ + # TODO(tehrengruber): Extend this function cover more cases than just those with equal + # structure, e.g., by also canonicalization of the structure. + return a == b or ( + CannonicalizeBoundSymbolNames.apply(a) == CannonicalizeBoundSymbolNames.apply(b) + ) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 42bbf28909..51daffed05 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -11,22 +11,29 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later -from dataclasses import dataclass +import dataclasses +import enum +import functools +import operator from typing import Optional from gt4py import eve +from gt4py.eve import utils as eve_utils from gt4py.next import type_inference from gt4py.next.iterator import ir, type_inference as it_type_inference +from gt4py.next.iterator.ir_utils import ir_makers as im, misc as ir_misc +from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_if_call, is_let +from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas, inline_lambda class UnknownLength: pass -def _get_tuple_size(elem: ir.Node, node_types: Optional[dict] = None) -> int | type[UnknownLength]: - if node_types: - type_ = node_types[id(elem)] - # global inference should always give a length, function should fail otherwise +def _get_tuple_size(elem: ir.Node, use_global_information: bool) -> int | type[UnknownLength]: + if use_global_information: + type_ = elem.annex.type + # global inference should always give a length, fail otherwise assert isinstance(type_, it_type_inference.Val) and isinstance( type_.dtype, it_type_inference.Tuple ) @@ -47,7 +54,31 @@ def _get_tuple_size(elem: ir.Node, node_types: Optional[dict] = None) -> int | t return len(type_.dtype) -@dataclass(frozen=True) +def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr): + """Given a itir.FunCall return a new call with one of its argument replaced.""" + return ir.FunCall( + fun=node.fun, args=[arg if i != arg_idx else new_arg for i, arg in enumerate(node.args)] + ) + + +def _is_trivial_make_tuple_call(node: ir.Expr): + """Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof.""" + if not (isinstance(node, ir.FunCall) and node.fun == im.ref("make_tuple")): + return False + if not all( + isinstance(arg, (ir.SymRef, ir.Literal)) or _is_trivial_make_tuple_call(arg) + for arg in node.args + ): + return False + return True + + +# TODO(tehrengruber): Conceptually the structure of this pass makes sense: Visit depth first, +# transform each node until no transformations apply anymore, whenever a node is to be transformed +# go through all available transformation and apply them. However the final result here still +# reads a little convoluted and is also different to how we write other transformations. We +# should revisit the pattern here and try to find a more general mechanism. +@dataclasses.dataclass(frozen=True) class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator): """ Simplifies `make_tuple`, `tuple_get` calls. @@ -56,10 +87,44 @@ class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator): - `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i` """ + # TODO(tehrengruber): This Flag mechanism is a little low level. What we actually want + # is something like a pass manager, where for each pattern we have a corresponding + # transformation, etc. + class Flag(enum.Flag): + #: `make_tuple(tuple_get(0, t), tuple_get(1, t), ..., tuple_get(N-1,t))` -> `t` + COLLAPSE_MAKE_TUPLE_TUPLE_GET = enum.auto() + #: `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i` + COLLAPSE_TUPLE_GET_MAKE_TUPLE = enum.auto() + #: `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))` + PROPAGATE_TUPLE_GET = enum.auto() + #: `{1, 2}` -> `(λ(_tuple_el_1, _tuple_el_2) → {_tuple_el_1, _tuple_el_2})(1, 2)` + LETIFY_MAKE_TUPLE_ELEMENTS = enum.auto() + #: `let(tup, {trivial_expr1, trivial_expr2})(foo(tup))` + #: -> `foo({trivial_expr1, trivial_expr2})` + INLINE_TRIVIAL_MAKE_TUPLE = enum.auto() + #: `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` + PROPAGATE_TO_IF_ON_TUPLES = enum.auto() + #: `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` + PROPAGATE_NESTED_LET = enum.auto() + #: `let(a, 1)(a)` -> `1` + INLINE_TRIVIAL_LET = enum.auto() + + @classmethod + def all(self): # noqa: A003 # shadowing a python builtin + return functools.reduce(operator.or_, self.__members__.values()) + ignore_tuple_size: bool - collapse_make_tuple_tuple_get: bool - collapse_tuple_get_make_tuple: bool use_global_type_inference: bool + flags: Flag = Flag.all() + + PRESERVED_ANNEX_ATTRS = ("type",) + + # we use one UID generator per instance such that the generated ids are + # stable across multiple runs (required for caching to properly work) + _letify_make_tuple_uids: eve_utils.UIDGenerator = dataclasses.field( + init=False, repr=False, default_factory=lambda: eve_utils.UIDGenerator(prefix="_tuple_el") + ) + _node_types: Optional[dict[int, type_inference.Type]] = None @classmethod @@ -68,34 +133,77 @@ def apply( node: ir.Node, *, ignore_tuple_size: bool = False, - # the following options are mostly for allowing separate testing of the modes - collapse_make_tuple_tuple_get: bool = True, - collapse_tuple_get_make_tuple: bool = True, use_global_type_inference: bool = False, + remove_letified_make_tuple_elements: bool = True, + # manually passing flags is mostly for allowing separate testing of the modes + flags=None, ) -> ir.Node: """ Simplifies `make_tuple`, `tuple_get` calls. - If `ignore_tuple_size`, apply the transformation even if length of the inner tuple - is greater than the length of the outer tuple. + Arguments: + node: The node to transform. + + Keyword arguments: + ignore_tuple_size: Apply the transformation even if length of the inner tuple is greater + than the length of the outer tuple. + use_global_type_inference: Run global type inference to determine tuple sizes. + remove_letified_make_tuple_elements: Run `InlineLambdas` as a post-processing step + to remove left-overs from `LETIFY_MAKE_TUPLE_ELEMENTS` transformation. + `(λ(_tuple_el_1, _tuple_el_2) → {_tuple_el_1, _tuple_el_2})(1, 2)` -> {1, 2}` """ - node_types = it_type_inference.infer_all(node) if use_global_type_inference else None - return cls( - ignore_tuple_size, - collapse_make_tuple_tuple_get, - collapse_tuple_get_make_tuple, - use_global_type_inference, - node_types, + flags = flags or cls.flags + if use_global_type_inference: + it_type_inference.infer_all(node, save_to_annex=True) + + new_node = cls( + ignore_tuple_size=ignore_tuple_size, + use_global_type_inference=use_global_type_inference, + flags=flags, ).visit(node) - def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: - if ( - self.collapse_make_tuple_tuple_get - and node.fun == ir.SymRef(id="make_tuple") - and all( - isinstance(arg, ir.FunCall) and arg.fun == ir.SymRef(id="tuple_get") - for arg in node.args + # inline to remove left-overs from LETIFY_MAKE_TUPLE_ELEMENTS. this is important + # as otherwise two equal expressions containing a tuple will not be equal anymore + # and the CSE pass can not remove them. + # TODO(tehrengruber): test case for `scan(lambda carry: {1, 2})` + # (see solve_nonhydro_stencil_52_like_z_q_tup) + if remove_letified_make_tuple_elements: + new_node = InlineLambdas.apply( + new_node, opcount_preserving=True, force_inline_lambda_args=False ) + + return new_node + + def visit_FunCall(self, node: ir.FunCall) -> ir.Node: + node = self.generic_visit(node) + return self.fp_transform(node) + + def fp_transform(self, node: ir.Node) -> ir.Node: + while True: + new_node = self.transform(node) + if new_node is None: + break + assert new_node != node + node = new_node + return node + + def transform(self, node: ir.Node) -> Optional[ir.Node]: + if not isinstance(node, ir.FunCall): + return None + + for transformation in self.Flag: + if self.flags & transformation: + assert isinstance(transformation.name, str) + method = getattr(self, f"transform_{transformation.name.lower()}") + result = method(node) + if result is not None: + return result + return None + + def transform_collapse_make_tuple_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: + if node.fun == ir.SymRef(id="make_tuple") and all( + isinstance(arg, ir.FunCall) and arg.fun == ir.SymRef(id="tuple_get") + for arg in node.args ): # `make_tuple(tuple_get(0, t), tuple_get(1, t), ..., tuple_get(N-1,t))` -> `t` assert isinstance(node.args[0], ir.FunCall) @@ -104,17 +212,19 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: for i, v in enumerate(node.args): assert isinstance(v, ir.FunCall) assert isinstance(v.args[0], ir.Literal) - if not (int(v.args[0].value) == i and v.args[1] == first_expr): + if not (int(v.args[0].value) == i and ir_misc.is_equal(v.args[1], first_expr)): # tuple argument differs, just continue with the rest of the tree - return self.generic_visit(node) + return None - if self.ignore_tuple_size or _get_tuple_size(first_expr, self._node_types) == len( - node.args - ): + if self.ignore_tuple_size or _get_tuple_size( + first_expr, self.use_global_type_inference + ) == len(node.args): return first_expr + return None + + def transform_collapse_tuple_get_make_tuple(self, node: ir.FunCall) -> Optional[ir.Node]: if ( - self.collapse_tuple_get_make_tuple - and node.fun == ir.SymRef(id="tuple_get") + node.fun == ir.SymRef(id="tuple_get") and isinstance(node.args[1], ir.FunCall) and node.args[1].fun == ir.SymRef(id="make_tuple") and isinstance(node.args[0], ir.Literal) @@ -127,4 +237,106 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: make_tuple_call.args ), f"Index {idx} is out of bounds for tuple of size {len(make_tuple_call.args)}" return node.args[1].args[idx] - return self.generic_visit(node) + return None + + def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: + if node.fun == ir.SymRef(id="tuple_get") and isinstance(node.args[0], ir.Literal): + # TODO(tehrengruber): extend to general symbols as long as the tail call in the let + # does not capture + # `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))` + if is_let(node.args[1]): + idx, let_expr = node.args + return im.call( + im.lambda_(*let_expr.fun.params)( # type: ignore[attr-defined] # ensured by is_let + self.fp_transform(im.tuple_get(idx.value, let_expr.fun.expr)) # type: ignore[attr-defined] # ensured by is_let + ) + )( + *let_expr.args # type: ignore[attr-defined] # ensured by is_let + ) + elif isinstance(node.args[1], ir.FunCall) and node.args[1].fun == im.ref("if_"): + idx = node.args[0] + cond, true_branch, false_branch = node.args[1].args + return im.if_( + cond, + self.fp_transform(im.tuple_get(idx.value, true_branch)), + self.fp_transform(im.tuple_get(idx.value, false_branch)), + ) + return None + + def transform_letify_make_tuple_elements(self, node: ir.FunCall) -> Optional[ir.Node]: + if node.fun == ir.SymRef(id="make_tuple"): + # `make_tuple(expr1, expr1)` + # -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))` + bound_vars: dict[str, ir.Expr] = {} + new_args: list[ir.Expr] = [] + for arg in node.args: + if ( + isinstance(node, ir.FunCall) + and node.fun == im.ref("make_tuple") + and not _is_trivial_make_tuple_call(node) + ): + el_name = self._letify_make_tuple_uids.sequential_id() + new_args.append(im.ref(el_name)) + bound_vars[el_name] = arg + else: + new_args.append(arg) + + if bound_vars: + return self.fp_transform(im.let(*bound_vars.items())(im.call(node.fun)(*new_args))) # type: ignore[arg-type] # mypy not smart enough + return None + + def transform_inline_trivial_make_tuple(self, node: ir.FunCall) -> Optional[ir.Node]: + if is_let(node): + # `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))` + # -> `foo(make_tuple(trivial_expr1, trivial_expr2))` + eligible_params = [_is_trivial_make_tuple_call(arg) for arg in node.args] + if any(eligible_params): + return self.visit(inline_lambda(node, eligible_params=eligible_params)) + return None + + def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.Node]: + if not node.fun == im.ref("if_"): + # TODO(tehrengruber): This significantly increases the size of the tree. Revisit. + # TODO(tehrengruber): Only inline if type of branch value is a tuple. + # Examples: + # `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` + # `let (b, if cond then {1, 2} else {3, 4})) b[0]` + # -> `if cond then let(b, {1, 2})(b[0]) else let(b, {3, 4})(b[0])` + for i, arg in enumerate(node.args): + if is_if_call(arg): + cond, true_branch, false_branch = arg.args + new_true_branch = self.fp_transform(_with_altered_arg(node, i, true_branch)) + new_false_branch = self.fp_transform(_with_altered_arg(node, i, false_branch)) + return im.if_(cond, new_true_branch, new_false_branch) + return None + + def transform_propagate_nested_let(self, node: ir.FunCall) -> Optional[ir.Node]: + if is_let(node): + # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` + outer_vars = {} + inner_vars = {} + original_inner_expr = node.fun.expr # type: ignore[attr-defined] # ensured by is_let + for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let + assert arg_sym not in inner_vars # TODO(tehrengruber): fix collisions + if is_let(arg): + for sym, val in zip(arg.fun.params, arg.args): # type: ignore[attr-defined] # ensured by is_let + assert sym not in outer_vars # TODO(tehrengruber): fix collisions + outer_vars[sym] = val + inner_vars[arg_sym] = arg.fun.expr # type: ignore[attr-defined] # ensured by is_let + else: + inner_vars[arg_sym] = arg + if outer_vars: + return self.fp_transform( + im.let(*outer_vars.items())( # type: ignore[arg-type] # mypy not smart enough + self.fp_transform(im.let(*inner_vars.items())(original_inner_expr)) + ) + ) + return None + + def transform_inline_trivial_let(self, node: ir.FunCall) -> Optional[ir.Node]: + if is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let + # `let(a, 1)(a)` -> `1` + for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let + if node.fun.expr == im.ref(arg_sym.id): # type: ignore[attr-defined] # ensured by is_let + return arg + return None diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index fe14a8f580..b9dcc094c4 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -117,6 +117,10 @@ def apply_common_transforms( # to limit number of times global type inference is executed, only in the last iterations. use_global_type_inference=inlined == ir, ) + # This pass is required such that a deref outside of a + # `tuple_get(make_tuple(let(...), ...))` call is propagated into the let after the + # `tuple_get` is removed by the `CollapseTuple` pass. + inlined = PropagateDeref.apply(inlined) if inlined == ir: break diff --git a/src/gt4py/next/iterator/transforms/propagate_deref.py b/src/gt4py/next/iterator/transforms/propagate_deref.py index 783e54ede0..9f8bff7a84 100644 --- a/src/gt4py/next/iterator/transforms/propagate_deref.py +++ b/src/gt4py/next/iterator/transforms/propagate_deref.py @@ -15,6 +15,7 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.eve.pattern_matching import ObjectPattern as P from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im # TODO(tehrengruber): This pass can be generalized to all builtins, e.g. @@ -56,4 +57,12 @@ def visit_FunCall(self, node: ir.FunCall): ), args=lambda_args, ) + elif ( + node.fun == im.ref("deref") + and isinstance(node.args[0], ir.FunCall) + and node.args[0].fun == im.ref("if_") + ): + cond, true_branch, false_branch = node.args[0].args + return im.if_(cond, im.deref(true_branch), im.deref(false_branch)) + return self.generic_visit(node) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 1444b0a64f..330f66bee5 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -20,7 +20,11 @@ def test_simple_make_tuple_tuple_get(): tuple_of_size_2 = im.make_tuple("first", "second") testee = im.make_tuple(im.tuple_get(0, tuple_of_size_2), im.tuple_get(1, tuple_of_size_2)) - actual = CollapseTuple.apply(testee, collapse_tuple_get_make_tuple=False) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + ) expected = tuple_of_size_2 assert actual == expected @@ -32,7 +36,11 @@ def test_nested_make_tuple_tuple_get(): im.tuple_get(0, tup_of_size2_from_lambda), im.tuple_get(1, tup_of_size2_from_lambda) ) - actual = CollapseTuple.apply(testee, collapse_tuple_get_make_tuple=False) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + ) assert actual == tup_of_size2_from_lambda @@ -42,7 +50,11 @@ def test_different_tuples_make_tuple_tuple_get(): t1 = im.make_tuple("foo1", "bar1") testee = im.make_tuple(im.tuple_get(0, t0), im.tuple_get(1, t1)) - actual = CollapseTuple.apply(testee, collapse_tuple_get_make_tuple=False) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + ) assert actual == testee # did nothing @@ -50,24 +62,137 @@ def test_different_tuples_make_tuple_tuple_get(): def test_incompatible_order_make_tuple_tuple_get(): tuple_of_size_2 = im.make_tuple("first", "second") testee = im.make_tuple(im.tuple_get(1, tuple_of_size_2), im.tuple_get(0, tuple_of_size_2)) - actual = CollapseTuple.apply(testee, collapse_tuple_get_make_tuple=False) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + ) assert actual == testee # did nothing def test_incompatible_size_make_tuple_tuple_get(): testee = im.make_tuple(im.tuple_get(0, im.make_tuple("first", "second"))) - actual = CollapseTuple.apply(testee, collapse_tuple_get_make_tuple=False) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + ) assert actual == testee # did nothing def test_merged_with_smaller_outer_size_make_tuple_tuple_get(): testee = im.make_tuple(im.tuple_get(0, im.make_tuple("first", "second"))) - actual = CollapseTuple.apply(testee, ignore_tuple_size=True) + actual = CollapseTuple.apply( + testee, ignore_tuple_size=True, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET + ) assert actual == im.make_tuple("first", "second") def test_simple_tuple_get_make_tuple(): expected = im.ref("bar") testee = im.tuple_get(1, im.make_tuple("foo", expected)) - actual = CollapseTuple.apply(testee, collapse_make_tuple_tuple_get=False) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE, + ) + assert expected == actual + + +def test_propagate_tuple_get(): + expected = im.let(("el1", 1), ("el2", 2))(im.tuple_get(0, im.make_tuple("el1", "el2"))) + testee = im.tuple_get(0, im.let(("el1", 1), ("el2", 2))(im.make_tuple("el1", "el2"))) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.PROPAGATE_TUPLE_GET, + ) assert expected == actual + + +def test_letify_make_tuple_elements(): + opaque_call = im.call("opaque")() + testee = im.make_tuple(opaque_call, opaque_call) + expected = im.let(("_tuple_el_1", opaque_call), ("_tuple_el_2", opaque_call))( + im.make_tuple("_tuple_el_1", "_tuple_el_2") + ) + + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, + ) + assert actual == expected + + +def test_letify_make_tuple_with_trivial_elements(): + testee = im.let(("a", 1), ("b", 2))(im.make_tuple("a", "b")) + expected = testee # did nothing + + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, + ) + assert actual == expected + + +def test_inline_trivial_make_tuple(): + testee = im.let("tup", im.make_tuple("a", "b"))("tup") + expected = im.make_tuple("a", "b") + + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.INLINE_TRIVIAL_MAKE_TUPLE, + ) + assert actual == expected + + +def test_propagate_to_if_on_tuples(): + testee = im.tuple_get(0, im.if_("cond", im.make_tuple(1, 2), im.make_tuple(3, 4))) + expected = im.if_( + "cond", im.tuple_get(0, im.make_tuple(1, 2)), im.tuple_get(0, im.make_tuple(3, 4)) + ) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + ) + assert actual == expected + + +def test_propagate_to_if_on_tuples_with_let(): + testee = im.let("val", im.if_("cond", im.make_tuple(1, 2), im.make_tuple(3, 4)))( + im.tuple_get(0, "val") + ) + expected = im.if_( + "cond", im.tuple_get(0, im.make_tuple(1, 2)), im.tuple_get(0, im.make_tuple(3, 4)) + ) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=True, + flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES + | CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, + ) + assert actual == expected + + +def test_propagate_nested_lift(): + testee = im.let("a", im.let("b", 1)("a_val"))("a") + expected = im.let("b", 1)(im.let("a", "a_val")("a")) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.PROPAGATE_NESTED_LET, + ) + assert actual == expected + + +def test_if_on_tuples_with_let(): + testee = im.let("val", im.if_("cond", im.make_tuple(1, 2), im.make_tuple(3, 4)))( + im.tuple_get(0, "val") + ) + expected = im.if_("cond", 1, 3) + actual = CollapseTuple.apply(testee, remove_letified_make_tuple_elements=False) + assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index 065095e1c2..fb7720f4d7 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -213,15 +213,14 @@ def is_let(node: ir.Expr): testee = im.plus( im.let( - "c", - im.let( - "a", - 1, - "b", - 2, - )(im.plus("a", "b")), - "d", - 3, + ( + "c", + im.let( + ("a", 1), + ("b", 2), + )(im.plus("a", "b")), + ), + ("d", 3), )(im.plus("c", "d")), 4, )