From fe49d606ecf152315ffd123dcb81a101ae627116 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 17 Oct 2023 13:59:47 +0200 Subject: [PATCH 01/21] Improve CollapseTuple pass --- .../iterator/transforms/collapse_tuple.py | 163 ++++++++++++++++-- .../iterator/transforms/inline_lambdas.py | 6 + .../next/iterator/transforms/pass_manager.py | 1 + .../iterator/transforms/propagate_deref.py | 6 +- tests/next_tests/exclusion_matrices.py | 2 +- .../transforms_tests/test_collapse_tuple.py | 113 +++++++++++- 6 files changed, 264 insertions(+), 27 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 86f21072e5..1905990d53 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -11,12 +11,16 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later +import enum +from typing import Callable -from dataclasses import dataclass +import dataclasses from gt4py import eve +from gt4py.eve.utils import UIDGenerator from gt4py.next import type_inference -from gt4py.next.iterator import ir, type_inference as it_type_inference +from gt4py.next.iterator import ir, type_inference as it_type_inference, ir_makers as im +from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda def _get_tuple_size(type_: type_inference.Type) -> int: @@ -25,8 +29,29 @@ def _get_tuple_size(type_: type_inference.Type) -> int: ) return len(type_.dtype) +def _is_let(node: ir.Node) -> bool: + return isinstance(node, ir.FunCall) and isinstance(node.fun, ir.Lambda) -@dataclass(frozen=True) +def _is_if_call(node: ir.Expr): + return isinstance(node, ir.FunCall) and node.fun == im.ref("if_") + +def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr): + return ir.FunCall( + fun=node.fun, + args=[arg if i != arg_idx else new_arg for i, arg in enumerate(node.args)] + ) + +def _is_trivial_make_tuple_call(node: ir.Expr): + if not (isinstance(node, ir.FunCall) and node.fun == im.ref("make_tuple")): + return False + if not all(isinstance(arg, (ir.SymRef, ir.Literal)) or _is_trivial_make_tuple_call(arg) for arg in node.args): + return False + return True + +def nlet(bindings: list[tuple[ir.Sym, str], ir.Expr]): + return im.let(*[el for tup in bindings for el in tup]) + +@dataclasses.dataclass(frozen=True) class CollapseTuple(eve.NodeTranslator): """ Simplifies `make_tuple`, `tuple_get` calls. @@ -35,11 +60,41 @@ class CollapseTuple(eve.NodeTranslator): - `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i` """ + class Flag(enum.IntEnum): + #: `make_tuple(tuple_get(0, t), tuple_get(1, t), ..., tuple_get(N-1,t))` -> `t` + COLLAPSE_MAKE_TUPLE_TUPLE_GET = 1 + #: `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i` + COLLAPSE_TUPLE_GET_MAKE_TUPLE = 2 + #: `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))` + PROPAGATE_TUPLE_GET = 4 + #: `{1, 2}` -> `(λ(_tuple_el_1, _tuple_el_2) → {_tuple_el_1, _tuple_el_2})(1, 2)` + LETIFY_MAKE_TUPLE_ELEMENTS = 8 + #: TODO + INLINE_TRIVIAL_MAKE_TUPLE = 16 + #: TODO + PROPAGATE_TO_IF_ON_TUPLES = 32 + #: TODO + PROPAGATE_NESTED_LET=64 + #: TODO + INLINE_TRIVIAL_LET=128 + ignore_tuple_size: bool - collapse_make_tuple_tuple_get: bool - collapse_tuple_get_make_tuple: bool + flags: int = (Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET + | Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE + | Flag.PROPAGATE_TUPLE_GET + | Flag.LETIFY_MAKE_TUPLE_ELEMENTS + | Flag.INLINE_TRIVIAL_MAKE_TUPLE + | Flag.PROPAGATE_TO_IF_ON_TUPLES + | Flag.PROPAGATE_NESTED_LET + | Flag.INLINE_TRIVIAL_LET) - _node_types: dict[int, type_inference.Type] + PRESERVED_ANNEX_ATTRS = ("type",) + + # we use one UID generator per instance such that the generated ids are + # stable across multiple runs (required for caching to properly work) + _letify_make_tuple_uids: UIDGenerator = dataclasses.field( + init=False, repr=False, default_factory=lambda: UIDGenerator(prefix="_tuple_el") + ) @classmethod def apply( @@ -47,9 +102,8 @@ def apply( node: ir.Node, *, ignore_tuple_size: bool = False, - # the following options are mostly for allowing separate testing of the modes - collapse_make_tuple_tuple_get: bool = True, - collapse_tuple_get_make_tuple: bool = True, + # manually passing flags is mostly for allowing separate testing of the modes + flags = None ) -> ir.Node: """ Simplifies `make_tuple`, `tuple_get` calls. @@ -57,18 +111,18 @@ def apply( If `ignore_tuple_size`, apply the transformation even if length of the inner tuple is greater than the length of the outer tuple. """ - node_types = it_type_inference.infer_all(node) + it_type_inference.infer_all(node, save_to_annex=True) return cls( - ignore_tuple_size, - collapse_make_tuple_tuple_get, - collapse_tuple_get_make_tuple, - node_types, + ignore_tuple_size=ignore_tuple_size, + flags=flags or cls.flags ).visit(node) def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: + node = self.generic_visit(node) + if ( - self.collapse_make_tuple_tuple_get + self.flags & self.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET and node.fun == ir.SymRef(id="make_tuple") and all( isinstance(arg, ir.FunCall) and arg.fun == ir.SymRef(id="tuple_get") @@ -86,12 +140,13 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: # tuple argument differs, just continue with the rest of the tree return self.generic_visit(node) - if self.ignore_tuple_size or _get_tuple_size(self._node_types[id(first_expr)]) == len( + if self.ignore_tuple_size or _get_tuple_size(first_expr.annex.type) == len( node.args ): return first_expr + if ( - self.collapse_tuple_get_make_tuple + self.flags & self.Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE and node.fun == ir.SymRef(id="tuple_get") and isinstance(node.args[1], ir.FunCall) and node.args[1].fun == ir.SymRef(id="make_tuple") @@ -105,4 +160,76 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: make_tuple_call.args ), f"Index {idx} is out of bounds for tuple of size {len(make_tuple_call.args)}" return node.args[1].args[idx] - return self.generic_visit(node) + + if ( + self.flags & self.Flag.PROPAGATE_TUPLE_GET + and node.fun == ir.SymRef(id="tuple_get") + and isinstance(node.args[0], ir.Literal) # TODO: extend to general symbols as long as the tail call in the let does not capture + ): + # `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))` + if _is_let(node.args[1]): + idx, let_expr = node.args + return self.visit( + im.call(im.lambda_(*let_expr.fun.params)(im.tuple_get(idx, let_expr.fun.expr)))(*let_expr.args) + ) + elif isinstance(node.args[1], ir.FunCall) and node.args[1].fun == im.ref("if_"): + idx = node.args[0] + cond, true_branch, false_branch = node.args[1].args + return self.visit( + im.if_(cond, im.tuple_get(idx, true_branch), im.tuple_get(idx, false_branch)) + ) # todo: check if visit needed + + if ( + self.flags & self.Flag.LETIFY_MAKE_TUPLE_ELEMENTS + and node.fun == ir.SymRef(id="make_tuple") + ): + bound_vars: dict[str, ir.Expr] = {} + new_args: list[ir.Expr] = [] + for i, arg in enumerate(node.args): + if isinstance(node, ir.FunCall) and node.fun == im.ref( + "make_tuple") and not _is_trivial_make_tuple_call(node): + el_name = self._letify_make_tuple_uids.sequential_id() + new_args.append(im.ref(el_name)) + bound_vars[el_name] = arg + else: + new_args.append(arg) + + if bound_vars: + return self.visit(im.let(*(el for item in bound_vars.items() for el in item))( + im.call(node.fun)(*new_args))) + + if self.flags & self.Flag.INLINE_TRIVIAL_MAKE_TUPLE and _is_let(node): + eligible_params = [_is_trivial_make_tuple_call(arg) for arg in node.args] + if any(eligible_params): + return self.visit(inline_lambda(node, eligible_params=eligible_params)) + + if self.flags & self.Flag.PROPAGATE_TO_IF_ON_TUPLES and not node.fun == im.ref("if_"): + # TODO: only do if type of branch value is a tuple + for i, arg in enumerate(node.args): + if _is_if_call(arg): + cond, true_branch, false_branch = arg.args + new_true_branch = self.visit(_with_altered_arg(node, i, true_branch), **kwargs) + new_false_branch = self.visit(_with_altered_arg(node, i, false_branch), **kwargs) + return im.if_(cond, new_true_branch, new_false_branch) + + if self.flags & self.Flag.PROPAGATE_NESTED_LET and _is_let(node): + outer_vars = {} + inner_vars = {} + original_inner_expr = node.fun.expr + for arg_sym, arg in zip(node.fun.params, node.args): + if _is_let(arg): + for sym, val in zip(arg.fun.params, arg.args): + outer_vars[sym] = val + assert arg_sym not in inner_vars # TODO: fix collisions + inner_vars[arg_sym] = arg.fun.expr + else: + inner_vars[arg_sym] = arg + if outer_vars: + node = self.visit(nlet(tuple(outer_vars.items()))(nlet(tuple(inner_vars.items()))(original_inner_expr))) + + if self.flags & self.Flag.INLINE_TRIVIAL_LET and _is_let(node) and isinstance(node.fun.expr, ir.SymRef): + for arg_sym, arg in zip(node.fun.params, node.args): + if node.fun.expr == im.ref(arg_sym.id): + return arg + + return node \ No newline at end of file diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index fc268f85e3..ca638e61c6 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -59,6 +59,12 @@ def inline_lambda( # noqa: C901 # see todo above if is_applied_lift(arg) and len(arg.args) == 0: eligible_params[i] = True + # TODO(tehrengruber): make configurable + if True: + for i, arg in enumerate(node.args): + if isinstance(arg, ir.Lambda): + eligible_params[i] = True + if node.fun.params and not any(eligible_params): return node diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 0ff3ec25c7..581f09aeab 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -107,6 +107,7 @@ def apply_common_transforms( # This pass is required to be in the loop such that when an `if_` call with tuple arguments # is constant-folded the surrounding tuple_get calls can be removed. inlined = CollapseTuple.apply(inlined) + inlined = PropagateDeref.apply(inlined) # todo: document if inlined == ir: break diff --git a/src/gt4py/next/iterator/transforms/propagate_deref.py b/src/gt4py/next/iterator/transforms/propagate_deref.py index 54bdafcda8..cae16de31e 100644 --- a/src/gt4py/next/iterator/transforms/propagate_deref.py +++ b/src/gt4py/next/iterator/transforms/propagate_deref.py @@ -14,7 +14,7 @@ from gt4py.eve import NodeTranslator from gt4py.eve.pattern_matching import ObjectPattern as P -from gt4py.next.iterator import ir +from gt4py.next.iterator import ir, ir_makers as im # TODO(tehrengruber): This pass can be generalized to all builtins, e.g. @@ -56,4 +56,8 @@ def visit_FunCall(self, node: ir.FunCall): ), args=lambda_args, ) + elif node.fun == im.ref("deref") and isinstance(node.args[0], ir.FunCall) and node.args[0].fun == im.ref("if_"): + cond, true_branch, false_branch = node.args[0].args + return im.if_(cond, im.deref(true_branch), im.deref(false_branch)) + return self.generic_visit(node) diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index d0a44080ad..db80fe8487 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -55,7 +55,7 @@ GTFN_SKIP_TEST_LIST = [ (REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), (USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), + #(USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 736bf04d64..fda9578a2b 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -22,7 +22,7 @@ def test_simple_make_tuple_tuple_get(): tuple_of_size_2 = im.make_tuple("first", "second") testee = im.make_tuple(im.tuple_get(0, tuple_of_size_2), im.tuple_get(1, tuple_of_size_2)) - actual = CollapseTuple.apply(testee, collapse_tuple_get_make_tuple=False) + actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET) expected = tuple_of_size_2 assert actual == expected @@ -34,7 +34,7 @@ def test_nested_make_tuple_tuple_get(): im.tuple_get(0, tup_of_size2_from_lambda), im.tuple_get(1, tup_of_size2_from_lambda) ) - actual = CollapseTuple.apply(testee, collapse_tuple_get_make_tuple=False) + actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET) assert actual == tup_of_size2_from_lambda @@ -44,7 +44,7 @@ def test_different_tuples_make_tuple_tuple_get(): t1 = im.make_tuple("foo1", "bar1") testee = im.make_tuple(im.tuple_get(0, t0), im.tuple_get(1, t1)) - actual = CollapseTuple.apply(testee, collapse_tuple_get_make_tuple=False) + actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET) assert actual == testee # did nothing @@ -52,24 +52,123 @@ def test_different_tuples_make_tuple_tuple_get(): def test_incompatible_order_make_tuple_tuple_get(): tuple_of_size_2 = im.make_tuple("first", "second") testee = im.make_tuple(im.tuple_get(1, tuple_of_size_2), im.tuple_get(0, tuple_of_size_2)) - actual = CollapseTuple.apply(testee, collapse_tuple_get_make_tuple=False) + actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET) assert actual == testee # did nothing def test_incompatible_size_make_tuple_tuple_get(): testee = im.make_tuple(im.tuple_get(0, im.make_tuple("first", "second"))) - actual = CollapseTuple.apply(testee, collapse_tuple_get_make_tuple=False) + actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET) assert actual == testee # did nothing def test_merged_with_smaller_outer_size_make_tuple_tuple_get(): testee = im.make_tuple(im.tuple_get(0, im.make_tuple("first", "second"))) - actual = CollapseTuple.apply(testee, ignore_tuple_size=True) + actual = CollapseTuple.apply(testee, ignore_tuple_size=True, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET) assert actual == im.make_tuple("first", "second") def test_simple_tuple_get_make_tuple(): expected = im.ref("bar") testee = im.tuple_get(1, im.make_tuple("foo", expected)) - actual = CollapseTuple.apply(testee, collapse_make_tuple_tuple_get=False) + actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE) assert expected == actual + +def test_propagate_tuple_get(): + expected = im.let("el1", 1, "el2", 2)(im.tuple_get(0, im.make_tuple("el1", "el2"))) + testee = im.tuple_get(0, im.let("el1", 1, "el2", 2)(im.make_tuple("el1", "el2"))) + actual = CollapseTuple.apply( + testee, + flags=CollapseTuple.Flag.PROPAGATE_TUPLE_GET + ) + assert expected == actual + +def test_letify_make_tuple_elements(): + opaque_call = im.call("opaque")() + testee = im.make_tuple(opaque_call, opaque_call) + expected = im.let("_tuple_el_1", opaque_call, "_tuple_el_2", opaque_call)( + im.make_tuple("_tuple_el_1", "_tuple_el_2") + ) + + actual = CollapseTuple.apply( + testee, + flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS + ) + assert actual == expected + + +def test_letify_make_tuple_with_trivial_elements(): + testee = im.let("a", 1, "b", 2)( + im.make_tuple("a", "b") + ) + expected = testee # did nothing + + actual = CollapseTuple.apply( + testee, + flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS + ) + assert actual == expected + +def test_inline_trivial_make_tuple(): + testee = im.let("tup", im.make_tuple("a", "b"))("tup") + expected = im.make_tuple("a", "b") + + actual = CollapseTuple.apply( + testee, + flags=CollapseTuple.Flag.INLINE_TRIVIAL_MAKE_TUPLE + ) + assert actual == expected + +def test_propagate_to_if_on_tuples(): + testee = im.tuple_get(0, + im.if_("cond", + im.make_tuple(1, 2), + im.make_tuple(3, 4))) + expected = im.if_("cond", + im.tuple_get(0, + im.make_tuple(1, 2)), + im.tuple_get(0, + im.make_tuple(3, 4))) + actual = CollapseTuple.apply( + testee, + flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES + ) + assert actual == expected + +def test_propagate_to_if_on_tuples_with_let(): + testee = im.let("val", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4)))( + im.tuple_get(0, "val") + ) + expected = im.if_("cond", + im.tuple_get(0, + im.make_tuple(1, 2)), + im.tuple_get(0, + im.make_tuple(3, 4))) + actual = CollapseTuple.apply( + testee, + #flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES + ) + assert actual == expected + + +def test_propagate_nested_lift(): + testee = im.let("a", im.let("b", 1)("a_val"))("a") + expected = im.let("b", 1)(im.let("a", "a_val")("a")) + actual = CollapseTuple.apply( + testee, + flags=CollapseTuple.Flag.PROPAGATE_NESTED_LET + ) + assert actual == expected + + +def test_collapse_complicated_(): + # TODO: fuse with test_propagate_to_if_on_tuples_with_let + testee = im.let("val", im.if_("cond", im.make_tuple(1, 2), im.make_tuple(3, 4)))( + im.tuple_get(0, "val") + ) + expected = im.if_("cond", 1, 3) + actual = CollapseTuple.apply( + testee, + #flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES + ) + assert actual == expected \ No newline at end of file From cd7b6e39e5663b20e8990f62bb078d746bc4531d Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 17 Oct 2023 16:17:42 +0200 Subject: [PATCH 02/21] Add comments --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 1905990d53..234a6a1d0a 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -183,6 +183,8 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: self.flags & self.Flag.LETIFY_MAKE_TUPLE_ELEMENTS and node.fun == ir.SymRef(id="make_tuple") ): + # `make_tuple(expr1, expr1)` + # -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))` bound_vars: dict[str, ir.Expr] = {} new_args: list[ir.Expr] = [] for i, arg in enumerate(node.args): @@ -199,12 +201,16 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: im.call(node.fun)(*new_args))) if self.flags & self.Flag.INLINE_TRIVIAL_MAKE_TUPLE and _is_let(node): + # `let(tup, make_tuple(trivial_expr1, trivial_expr2))(tup)` + # -> `make_tuple(trivial_expr1, trivial_expr2)` eligible_params = [_is_trivial_make_tuple_call(arg) for arg in node.args] if any(eligible_params): return self.visit(inline_lambda(node, eligible_params=eligible_params)) if self.flags & self.Flag.PROPAGATE_TO_IF_ON_TUPLES and not node.fun == im.ref("if_"): - # TODO: only do if type of branch value is a tuple + # TODO(tehrengruber): This significantly increases the size of the tree. Revisit. + # TODO(tehrengruber): Only inline if type of branch value is a tuple. + # `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` for i, arg in enumerate(node.args): if _is_if_call(arg): cond, true_branch, false_branch = arg.args @@ -213,6 +219,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: return im.if_(cond, new_true_branch, new_false_branch) if self.flags & self.Flag.PROPAGATE_NESTED_LET and _is_let(node): + # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` outer_vars = {} inner_vars = {} original_inner_expr = node.fun.expr @@ -228,6 +235,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: node = self.visit(nlet(tuple(outer_vars.items()))(nlet(tuple(inner_vars.items()))(original_inner_expr))) if self.flags & self.Flag.INLINE_TRIVIAL_LET and _is_let(node) and isinstance(node.fun.expr, ir.SymRef): + # `let(a, 1)(a)` -> `1` for arg_sym, arg in zip(node.fun.params, node.args): if node.fun.expr == im.ref(arg_sym.id): return arg From e49eae0765228b6bdb2da17397514254fd320597 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 19 Oct 2023 01:14:17 +0200 Subject: [PATCH 03/21] Bugfixes for scan --- .../iterator/transforms/collapse_tuple.py | 23 +++++++++++++++---- .../iterator/transforms/inline_lambdas.py | 8 ++++++- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 234a6a1d0a..5d0c515f8e 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -20,7 +20,7 @@ from gt4py.eve.utils import UIDGenerator from gt4py.next import type_inference from gt4py.next.iterator import ir, type_inference as it_type_inference, ir_makers as im -from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda +from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda, InlineLambdas def _get_tuple_size(type_: type_inference.Type) -> int: @@ -113,13 +113,27 @@ def apply( """ it_type_inference.infer_all(node, save_to_annex=True) - return cls( + new_node = cls( ignore_tuple_size=ignore_tuple_size, flags=flags or cls.flags ).visit(node) + # inline lambdas to remove left-overs from LETIFY_MAKE_TUPLE_ELEMENTS + # TODO: test case for `scan(lambda carry: {1, 2})` (see solve_nonhydro_stencil_52_like_z_q_tup) + new_node = InlineLambdas.apply(new_node, opcount_preserving=True, force_inline_lambda_args=False) + + # rerun with only some parts as LETIFY_MAKE_TUPLE_ELEMENTS might mess up the tree + # see `test_solve_nonhydro_stencil_52_like` + new_node = cls( + ignore_tuple_size=ignore_tuple_size, + flags=(cls.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET + | cls.Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE) & (flags or cls.flags) + ).visit(node) + + return new_node + def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: - node = self.generic_visit(node) + node = self.generic_visit(node, **kwargs) if ( self.flags & self.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET @@ -224,10 +238,11 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: inner_vars = {} original_inner_expr = node.fun.expr for arg_sym, arg in zip(node.fun.params, node.args): + assert arg_sym not in inner_vars # TODO: fix collisions if _is_let(arg): for sym, val in zip(arg.fun.params, arg.args): + assert sym not in outer_vars # TODO: fix collisions outer_vars[sym] = val - assert arg_sym not in inner_vars # TODO: fix collisions inner_vars[arg_sym] = arg.fun.expr else: inner_vars[arg_sym] = arg diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index ca638e61c6..f22d1385d0 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -29,6 +29,7 @@ def inline_lambda( # noqa: C901 # see todo above opcount_preserving=False, force_inline_lift_args=False, force_inline_trivial_lift_args=False, + force_inline_lambda_args=True, eligible_params: Optional[list[bool]] = None, ): assert isinstance(node.fun, ir.Lambda) @@ -60,7 +61,7 @@ def inline_lambda( # noqa: C901 # see todo above eligible_params[i] = True # TODO(tehrengruber): make configurable - if True: + if force_inline_lambda_args: for i, arg in enumerate(node.args): if isinstance(arg, ir.Lambda): eligible_params[i] = True @@ -130,6 +131,8 @@ class InlineLambdas(NodeTranslator): force_inline_trivial_lift_args: bool + force_inline_lambda_args: bool + @classmethod def apply( cls, @@ -137,6 +140,7 @@ def apply( opcount_preserving=False, force_inline_lift_args=False, force_inline_trivial_lift_args=False, + force_inline_lambda_args=True, ): """ Inline lambda calls by substituting every argument by its value. @@ -162,6 +166,7 @@ def apply( opcount_preserving=opcount_preserving, force_inline_lift_args=force_inline_lift_args, force_inline_trivial_lift_args=force_inline_trivial_lift_args, + force_inline_lambda_args=force_inline_lambda_args, ).visit(node) def visit_FunCall(self, node: ir.FunCall): @@ -172,6 +177,7 @@ def visit_FunCall(self, node: ir.FunCall): opcount_preserving=self.opcount_preserving, force_inline_lift_args=self.force_inline_lift_args, force_inline_trivial_lift_args=self.force_inline_trivial_lift_args, + force_inline_lambda_args=self.force_inline_lambda_args, ) return node From 4e750b44efac21e62ad9222e2d838487e7f1f629 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sat, 21 Oct 2023 13:42:05 +0200 Subject: [PATCH 04/21] Introduce `_is_equal_value_heuristics` to avoid double visit --- .../iterator/transforms/collapse_tuple.py | 63 +++++++++++++++---- 1 file changed, 51 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 5d0c515f8e..b1edcdc9b2 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -12,7 +12,9 @@ # # SPDX-License-Identifier: GPL-3.0-or-later import enum +from collections import ChainMap from typing import Callable +import hashlib import dataclasses @@ -51,6 +53,49 @@ def _is_trivial_make_tuple_call(node: ir.Expr): def nlet(bindings: list[tuple[ir.Sym, str], ir.Expr]): return im.let(*[el for tup in bindings for el in tup]) +def _short_hash(val: str) -> str: + return hashlib.sha1(val.encode('UTF-8')).hexdigest()[0:6] + +@dataclasses.dataclass(frozen=True) +class CannonicalizeBoundSymbolNames(eve.NodeTranslator): + """ + Given an iterator expression cannonicalize all bound symbol names. + + If two such expression are in the same scope and equal so are their values. + + >>> testee1 = im.lambda_("a")(im.plus("a", "b")) + >>> cannonicalized_testee1 = CannonicalizeBoundSymbolNames.apply(testee1) + >>> str(cannonicalized_testee1) + 'λ(_csym_1) → _csym_1 + b' + + >>> testee2 = im.lambda_("c")(im.plus("c", "b")) + >>> cannonicalized_testee2 = CannonicalizeBoundSymbolNames.apply(testee2) + >>> assert cannonicalized_testee1 == cannonicalized_testee2 + """ + _uids: UIDGenerator = dataclasses.field( + init=False, repr=False, default_factory=lambda: UIDGenerator(prefix="_csym") + ) + + @classmethod + def apply(cls, node: ir.Expr): + return cls().visit(node, sym_map=ChainMap({})) + + def visit_Lambda(self, node: ir.Lambda, *, sym_map: ChainMap): + sym_map = sym_map.new_child() + for param in node.params: + sym_map[str(param.id)] = self._uids.sequential_id() + + return im.lambda_(*sym_map.values())(self.visit(node.expr, sym_map=sym_map)) + + def visit_SymRef(self, node: ir.SymRef, *, sym_map: dict[str, str]): + return im.ref(sym_map[node.id]) if node.id in sym_map else node + +def _is_equal_value_heuristics(a: ir.Expr, b: ir.Expr): + """ + Return true if, bot not only if, two expression (with equal scope) have the same value. + """ + return a == b or (CannonicalizeBoundSymbolNames.apply(a) == CannonicalizeBoundSymbolNames.apply(b)) + @dataclasses.dataclass(frozen=True) class CollapseTuple(eve.NodeTranslator): """ @@ -118,18 +163,12 @@ def apply( flags=flags or cls.flags ).visit(node) - # inline lambdas to remove left-overs from LETIFY_MAKE_TUPLE_ELEMENTS + # inline to remove left-overs from LETIFY_MAKE_TUPLE_ELEMENTS. this is important + # as otherwise two equal expressions containing a tuple will not be equal anymore + # and the CSE pass can not remove them. # TODO: test case for `scan(lambda carry: {1, 2})` (see solve_nonhydro_stencil_52_like_z_q_tup) new_node = InlineLambdas.apply(new_node, opcount_preserving=True, force_inline_lambda_args=False) - # rerun with only some parts as LETIFY_MAKE_TUPLE_ELEMENTS might mess up the tree - # see `test_solve_nonhydro_stencil_52_like` - new_node = cls( - ignore_tuple_size=ignore_tuple_size, - flags=(cls.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET - | cls.Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE) & (flags or cls.flags) - ).visit(node) - return new_node def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: @@ -150,7 +189,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: for i, v in enumerate(node.args): assert isinstance(v, ir.FunCall) assert isinstance(v.args[0], ir.Literal) - if not (int(v.args[0].value) == i and v.args[1] == first_expr): + if not (int(v.args[0].value) == i and _is_equal_value_heuristics(v.args[1], first_expr)): # tuple argument differs, just continue with the rest of the tree return self.generic_visit(node) @@ -215,8 +254,8 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: im.call(node.fun)(*new_args))) if self.flags & self.Flag.INLINE_TRIVIAL_MAKE_TUPLE and _is_let(node): - # `let(tup, make_tuple(trivial_expr1, trivial_expr2))(tup)` - # -> `make_tuple(trivial_expr1, trivial_expr2)` + # `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))` + # -> `foo(make_tuple(trivial_expr1, trivial_expr2))` eligible_params = [_is_trivial_make_tuple_call(arg) for arg in node.args] if any(eligible_params): return self.visit(inline_lambda(node, eligible_params=eligible_params)) From e6c9f443c667dc73ca7d2fb806202c810f76bc3c Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 28 Nov 2023 14:46:23 +0100 Subject: [PATCH 05/21] Fix tests --- .../iterator_tests/transforms_tests/test_collapse_tuple.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 57ac9a6f12..6693805e56 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -135,6 +135,7 @@ def test_propagate_to_if_on_tuples(): ) assert actual == expected + def test_propagate_to_if_on_tuples_with_let(): testee = im.let("val", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4)))( im.tuple_get(0, "val") @@ -146,7 +147,7 @@ def test_propagate_to_if_on_tuples_with_let(): im.make_tuple(3, 4))) actual = CollapseTuple.apply( testee, - flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES + flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES | CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS ) assert actual == expected From 5a275243abee4f8e9b7589785248ad2299d9bfda Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 28 Nov 2023 15:27:09 +0100 Subject: [PATCH 06/21] Fix tests --- src/gt4py/next/ffront/decorator.py | 2 +- src/gt4py/next/ffront/foast_to_itir.py | 3 +- src/gt4py/next/iterator/ir_utils/__init__.py | 0 .../common_pattern_matcher.py | 11 +++ .../next/iterator/{ => ir_utils}/ir_makers.py | 0 .../ir_utils/is_equal_value_heuristics.py | 53 ++++++++++++++ src/gt4py/next/iterator/pretty_parser.py | 4 +- src/gt4py/next/iterator/tracing.py | 3 +- .../iterator/transforms/collapse_tuple.py | 73 ++++--------------- .../iterator/transforms/constant_folding.py | 3 +- .../next/iterator/transforms/global_tmps.py | 5 +- .../iterator/transforms/inline_lambdas.py | 2 +- .../next/iterator/transforms/inline_lifts.py | 3 +- .../iterator/transforms/propagate_deref.py | 3 +- .../next/iterator/transforms/unroll_reduce.py | 2 +- .../ffront_tests/test_foast_to_itir.py | 3 +- .../iterator_tests/test_type_inference.py | 3 +- .../transforms_tests/test_collapse_tuple.py | 4 +- .../transforms_tests/test_constant_folding.py | 2 +- .../transforms_tests/test_cse.py | 3 +- .../transforms_tests/test_global_tmps.py | 3 +- .../transforms_tests/test_inline_lambdas.py | 2 +- .../transforms_tests/test_inline_lifts.py | 2 +- .../transforms_tests/test_propagate_deref.py | 2 +- .../transforms_tests/test_trace_shifts.py | 3 +- 25 files changed, 111 insertions(+), 83 deletions(-) create mode 100644 src/gt4py/next/iterator/ir_utils/__init__.py rename src/gt4py/next/iterator/{transforms => ir_utils}/common_pattern_matcher.py (69%) rename src/gt4py/next/iterator/{ => ir_utils}/ir_makers.py (100%) create mode 100644 src/gt4py/next/iterator/ir_utils/is_equal_value_heuristics.py diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 107415eb06..2e872ab608 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -53,7 +53,7 @@ from gt4py.next.ffront.past_to_itir import ProgramLowering from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_makers import literal_from_value, promote_to_const_iterator, ref, sym +from gt4py.next.iterator.ir_utils.ir_makers import literal_from_value, promote_to_const_iterator, ref, sym from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.program_processors.runners import roundtrip from gt4py.next.type_system import type_info, type_specifications as ts, type_translation diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 816b8581f1..3030c03fd1 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -25,7 +25,8 @@ ) from gt4py.next.ffront.fbuiltins import FUN_BUILTIN_NAMES, MATH_BUILTIN_NAMES, TYPE_BUILTIN_NAMES from gt4py.next.ffront.foast_introspection import StmtReturnKind, deduce_stmt_return_kind -from gt4py.next.iterator import ir as itir, ir_makers as im +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_info, type_specifications as ts diff --git a/src/gt4py/next/iterator/ir_utils/__init__.py b/src/gt4py/next/iterator/ir_utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/gt4py/next/iterator/transforms/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py similarity index 69% rename from src/gt4py/next/iterator/transforms/common_pattern_matcher.py rename to src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 8df4723502..66a35e32af 100644 --- a/src/gt4py/next/iterator/transforms/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -14,6 +14,7 @@ from typing import TypeGuard from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]: @@ -24,3 +25,13 @@ def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]: and isinstance(arg.fun.fun, itir.SymRef) and arg.fun.fun.id == "lift" ) + + +def is_let(node: itir.Node) -> bool: + """Match expression of the form `(λ(...) → ...)(...)`""" + return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda) + + +def is_if_call(node: itir.Expr): + """Match expression of the form `if_(cond, true_branch, false_branch)`""" + return isinstance(node, itir.FunCall) and node.fun == im.ref("if_") \ No newline at end of file diff --git a/src/gt4py/next/iterator/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py similarity index 100% rename from src/gt4py/next/iterator/ir_makers.py rename to src/gt4py/next/iterator/ir_utils/ir_makers.py diff --git a/src/gt4py/next/iterator/ir_utils/is_equal_value_heuristics.py b/src/gt4py/next/iterator/ir_utils/is_equal_value_heuristics.py new file mode 100644 index 0000000000..6b153edebc --- /dev/null +++ b/src/gt4py/next/iterator/ir_utils/is_equal_value_heuristics.py @@ -0,0 +1,53 @@ +import dataclasses +from collections import ChainMap + +from gt4py import eve +from gt4py.eve import utils as eve_utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im + + +@dataclasses.dataclass(frozen=True) +class CannonicalizeBoundSymbolNames(eve.NodeTranslator): + """ + Given an iterator expression cannonicalize all bound symbol names. + + If two such expression are in the same scope and equal so are their values. + + >>> testee1 = im.lambda_("a")(im.plus("a", "b")) + >>> cannonicalized_testee1 = CannonicalizeBoundSymbolNames.apply(testee1) + >>> str(cannonicalized_testee1) + 'λ(_csym_1) → _csym_1 + b' + + >>> testee2 = im.lambda_("c")(im.plus("c", "b")) + >>> cannonicalized_testee2 = CannonicalizeBoundSymbolNames.apply(testee2) + >>> assert cannonicalized_testee1 == cannonicalized_testee2 + """ + _uids: eve_utils.UIDGenerator = dataclasses.field( + init=False, repr=False, default_factory=lambda: eve_utils.UIDGenerator(prefix="_csym") + ) + + @classmethod + def apply(cls, node: itir.Expr): + return cls().visit(node, sym_map=ChainMap({})) + + def visit_Lambda(self, node: itir.Lambda, *, sym_map: ChainMap): + sym_map = sym_map.new_child() + for param in node.params: + sym_map[str(param.id)] = self._uids.sequential_id() + + return im.lambda_(*sym_map.values())(self.visit(node.expr, sym_map=sym_map)) + + def visit_SymRef(self, node: itir.SymRef, *, sym_map: dict[str, str]): + return im.ref(sym_map[node.id]) if node.id in sym_map else node + + +def is_equal_value_heuristics(a: itir.Expr, b: itir.Expr): + """ + Return true if, bot not only if, two expression (with equal scope) have the same value. + + >>> testee1 = im.lambda_("a")(im.plus("a", "b")) + >>> testee2 = im.lambda_("c")(im.plus("c", "b")) + >>> assert is_equal_value_heuristics(testee1, testee2) + """ + return a == b or (CannonicalizeBoundSymbolNames.apply(a) == CannonicalizeBoundSymbolNames.apply(b)) \ No newline at end of file diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index a541e985ad..78f6fc9e45 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -16,8 +16,8 @@ from lark import lark, lexer as lark_lexer, visitors as lark_visitors -from gt4py.next.iterator import ir, ir_makers as im - +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im GRAMMAR = """ start: fencil_definition diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index fbe6a2ae82..30d3c3225f 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -20,7 +20,8 @@ from gt4py._core import definitions as core_defs from gt4py.eve import Node from gt4py.next import common, iterator -from gt4py.next.iterator import builtins, ir_makers as im +from gt4py.next.iterator import builtins +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.ir import ( AxisLiteral, Expr, diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 2651b0f46b..98da52ee91 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -12,16 +12,16 @@ # # SPDX-License-Identifier: GPL-3.0-or-later import enum -from collections import ChainMap -from typing import Callable, Optional -import hashlib +from typing import Optional import dataclasses from gt4py import eve -from gt4py.eve.utils import UIDGenerator +import gt4py.eve.utils from gt4py.next import type_inference -from gt4py.next.iterator import ir, type_inference as it_type_inference, ir_makers as im +from gt4py.next.iterator import ir, type_inference as it_type_inference +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_let, is_if_call from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda, InlineLambdas @@ -49,13 +49,9 @@ def _get_tuple_size(elem: ir.Node, use_global_information: bool) -> int | type[U return len(type_.dtype) -def _is_let(node: ir.Node) -> bool: - return isinstance(node, ir.FunCall) and isinstance(node.fun, ir.Lambda) - -def _is_if_call(node: ir.Expr): - return isinstance(node, ir.FunCall) and node.fun == im.ref("if_") def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr): + """Given a itir.FunCall return a new call with one of its argument replaced.""" return ir.FunCall( fun=node.fun, args=[arg if i != arg_idx else new_arg for i, arg in enumerate(node.args)] @@ -71,48 +67,7 @@ def _is_trivial_make_tuple_call(node: ir.Expr): def nlet(bindings: list[tuple[ir.Sym, str], ir.Expr]): return im.let(*[el for tup in bindings for el in tup]) -def _short_hash(val: str) -> str: - return hashlib.sha1(val.encode('UTF-8')).hexdigest()[0:6] - -@dataclasses.dataclass(frozen=True) -class CannonicalizeBoundSymbolNames(eve.NodeTranslator): - """ - Given an iterator expression cannonicalize all bound symbol names. - - If two such expression are in the same scope and equal so are their values. - - >>> testee1 = im.lambda_("a")(im.plus("a", "b")) - >>> cannonicalized_testee1 = CannonicalizeBoundSymbolNames.apply(testee1) - >>> str(cannonicalized_testee1) - 'λ(_csym_1) → _csym_1 + b' - - >>> testee2 = im.lambda_("c")(im.plus("c", "b")) - >>> cannonicalized_testee2 = CannonicalizeBoundSymbolNames.apply(testee2) - >>> assert cannonicalized_testee1 == cannonicalized_testee2 - """ - _uids: UIDGenerator = dataclasses.field( - init=False, repr=False, default_factory=lambda: UIDGenerator(prefix="_csym") - ) - - @classmethod - def apply(cls, node: ir.Expr): - return cls().visit(node, sym_map=ChainMap({})) - - def visit_Lambda(self, node: ir.Lambda, *, sym_map: ChainMap): - sym_map = sym_map.new_child() - for param in node.params: - sym_map[str(param.id)] = self._uids.sequential_id() - return im.lambda_(*sym_map.values())(self.visit(node.expr, sym_map=sym_map)) - - def visit_SymRef(self, node: ir.SymRef, *, sym_map: dict[str, str]): - return im.ref(sym_map[node.id]) if node.id in sym_map else node - -def _is_equal_value_heuristics(a: ir.Expr, b: ir.Expr): - """ - Return true if, bot not only if, two expression (with equal scope) have the same value. - """ - return a == b or (CannonicalizeBoundSymbolNames.apply(a) == CannonicalizeBoundSymbolNames.apply(b)) @dataclasses.dataclass(frozen=True) class CollapseTuple(eve.NodeTranslator): @@ -156,8 +111,8 @@ class Flag(enum.IntEnum): # we use one UID generator per instance such that the generated ids are # stable across multiple runs (required for caching to properly work) - _letify_make_tuple_uids: UIDGenerator = dataclasses.field( - init=False, repr=False, default_factory=lambda: UIDGenerator(prefix="_tuple_el") + _letify_make_tuple_uids: eve.utils.UIDGenerator = dataclasses.field( + init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="_tuple_el") ) _node_types: Optional[dict[int, type_inference.Type]] = None @@ -247,7 +202,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: and isinstance(node.args[0], ir.Literal) # TODO: extend to general symbols as long as the tail call in the let does not capture ): # `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))` - if _is_let(node.args[1]): + if is_let(node.args[1]): idx, let_expr = node.args return self.visit( im.call(im.lambda_(*let_expr.fun.params)(im.tuple_get(idx, let_expr.fun.expr)))(*let_expr.args) @@ -280,7 +235,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: return self.visit(im.let(*(el for item in bound_vars.items() for el in item))( im.call(node.fun)(*new_args))) - if self.flags & self.Flag.INLINE_TRIVIAL_MAKE_TUPLE and _is_let(node): + if self.flags & self.Flag.INLINE_TRIVIAL_MAKE_TUPLE and is_let(node): # `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))` # -> `foo(make_tuple(trivial_expr1, trivial_expr2))` eligible_params = [_is_trivial_make_tuple_call(arg) for arg in node.args] @@ -292,20 +247,20 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: # TODO(tehrengruber): Only inline if type of branch value is a tuple. # `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` for i, arg in enumerate(node.args): - if _is_if_call(arg): + if is_if_call(arg): cond, true_branch, false_branch = arg.args new_true_branch = self.visit(_with_altered_arg(node, i, true_branch), **kwargs) new_false_branch = self.visit(_with_altered_arg(node, i, false_branch), **kwargs) return im.if_(cond, new_true_branch, new_false_branch) - if self.flags & self.Flag.PROPAGATE_NESTED_LET and _is_let(node): + if self.flags & self.Flag.PROPAGATE_NESTED_LET and is_let(node): # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` outer_vars = {} inner_vars = {} original_inner_expr = node.fun.expr for arg_sym, arg in zip(node.fun.params, node.args): assert arg_sym not in inner_vars # TODO: fix collisions - if _is_let(arg): + if is_let(arg): for sym, val in zip(arg.fun.params, arg.args): assert sym not in outer_vars # TODO: fix collisions outer_vars[sym] = val @@ -315,7 +270,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: if outer_vars: node = self.visit(nlet(tuple(outer_vars.items()))(nlet(tuple(inner_vars.items()))(original_inner_expr))) - if self.flags & self.Flag.INLINE_TRIVIAL_LET and _is_let(node) and isinstance(node.fun.expr, ir.SymRef): + if self.flags & self.Flag.INLINE_TRIVIAL_LET and is_let(node) and isinstance(node.fun.expr, ir.SymRef): # `let(a, 1)(a)` -> `1` for arg_sym, arg in zip(node.fun.params, node.args): if node.fun.expr == im.ref(arg_sym.id): diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index cda422f30d..fa326760b0 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -13,7 +13,8 @@ # SPDX-License-Identifier: GPL-3.0-or-later from gt4py.eve import NodeTranslator -from gt4py.next.iterator import embedded, ir, ir_makers as im +from gt4py.next.iterator import embedded, ir +from gt4py.next.iterator.ir_utils import ir_makers as im class ConstantFolding(NodeTranslator): diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index e1b697e0bc..4faafa96f5 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -22,10 +22,11 @@ from gt4py.eve import Coerced, NodeTranslator from gt4py.eve.traits import SymbolTableTrait from gt4py.eve.utils import UIDGenerator -from gt4py.next.iterator import ir, ir_makers as im, type_inference +from gt4py.next.iterator import ir, type_inference +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.pretty_printer import PrettyPrinter from gt4py.next.iterator.transforms import trace_shifts -from gt4py.next.iterator.transforms.common_pattern_matcher import is_applied_lift +from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.transforms.cse import extract_subexpression from gt4py.next.iterator.transforms.eta_reduction import EtaReduction from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index f22d1385d0..8225119c33 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -17,7 +17,7 @@ from gt4py.eve import NodeTranslator from gt4py.next.iterator import ir -from gt4py.next.iterator.transforms.common_pattern_matcher import is_applied_lift +from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols from gt4py.next.iterator.transforms.symbol_ref_utils import CountSymbolRefs diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index 8d62450e67..d7d8e5e612 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -19,7 +19,8 @@ import gt4py.eve as eve from gt4py.eve import NodeTranslator, traits -from gt4py.next.iterator import ir, ir_makers as im +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda diff --git a/src/gt4py/next/iterator/transforms/propagate_deref.py b/src/gt4py/next/iterator/transforms/propagate_deref.py index cae16de31e..b01c49b4d9 100644 --- a/src/gt4py/next/iterator/transforms/propagate_deref.py +++ b/src/gt4py/next/iterator/transforms/propagate_deref.py @@ -14,7 +14,8 @@ from gt4py.eve import NodeTranslator from gt4py.eve.pattern_matching import ObjectPattern as P -from gt4py.next.iterator import ir, ir_makers as im +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im # TODO(tehrengruber): This pass can be generalized to all builtins, e.g. diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index e3084eaba5..60a5db7e96 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -20,7 +20,7 @@ from gt4py.eve.utils import UIDGenerator from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms.common_pattern_matcher import is_applied_lift +from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift def _is_shifted(arg: itir.Expr) -> TypeGuard[itir.FunCall]: diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py index dd66beb522..2dd4b91c48 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py @@ -29,7 +29,8 @@ from gt4py.next.ffront.ast_passes import single_static_assign as ssa from gt4py.next.ffront.foast_to_itir import FieldOperatorLowering from gt4py.next.ffront.func_to_foast import FieldOperatorParser -from gt4py.next.iterator import ir as itir, ir_makers as im +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_specifications as ts, type_translation diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 1526e97d74..cacdb7b070 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -15,7 +15,8 @@ import numpy as np import gt4py.next as gtx -from gt4py.next.iterator import ir, ir_makers as im, type_inference as ti +from gt4py.next.iterator import ir, type_inference as ti +from gt4py.next.iterator.ir_utils import ir_makers as im def test_unsatisfiable_constraints(): diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 6693805e56..e42c08194d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -12,9 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import pytest - -from gt4py.next.iterator import ir, ir_makers as im +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index 5d052b1989..275412a537 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.next.iterator import ir_makers as im +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.constant_folding import ConstantFolding diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index 5d9e0933a7..065095e1c2 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -14,7 +14,8 @@ import textwrap from gt4py.eve.utils import UIDGenerator -from gt4py.next.iterator import ir, ir_makers as im +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.cse import ( CommonSubexpressionElimination as CSE, extract_subexpression, diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index 88f6ed517b..86c3c98c62 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -15,7 +15,8 @@ import gt4py.next as gtx from gt4py.eve.utils import UIDs -from gt4py.next.iterator import ir, ir_makers as im +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.global_tmps import ( AUTO_DOMAIN, FencilWithTemporaries, diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index b9f2ca16a1..88e554f349 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -14,7 +14,7 @@ import pytest -from gt4py.next.iterator import ir_makers as im +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py index 1da2b8a044..e1d440044d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py @@ -14,7 +14,7 @@ import pytest -from gt4py.next.iterator import ir_makers as im +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.inline_lifts import InlineLifts diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py index ffbf2c2c8e..e2e29cd4db 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.next.iterator import ir_makers as im +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.propagate_deref import PropagateDeref diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py index 2624a17ebd..47db632a5e 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py @@ -12,7 +12,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.next.iterator import ir, ir_makers as im +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.trace_shifts import Sentinel, TraceShifts From 0437fced486d9e267d3c2e972a2d790f634cf704 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 28 Nov 2023 15:39:12 +0100 Subject: [PATCH 07/21] Move iterator utils to dedicated module (in preparation for other PR with more utils) --- src/gt4py/next/ffront/decorator.py | 7 ++++++- src/gt4py/next/ffront/foast_to_itir.py | 3 ++- src/gt4py/next/iterator/ir_utils/__init__.py | 13 +++++++++++++ .../common_pattern_matcher.py | 0 src/gt4py/next/iterator/{ => ir_utils}/ir_makers.py | 0 src/gt4py/next/iterator/pretty_parser.py | 3 ++- src/gt4py/next/iterator/tracing.py | 3 ++- .../next/iterator/transforms/constant_folding.py | 3 ++- src/gt4py/next/iterator/transforms/global_tmps.py | 5 +++-- .../next/iterator/transforms/inline_lambdas.py | 2 +- src/gt4py/next/iterator/transforms/inline_lifts.py | 3 ++- src/gt4py/next/iterator/transforms/unroll_reduce.py | 2 +- .../unit_tests/ffront_tests/test_foast_to_itir.py | 3 ++- .../iterator_tests/test_type_inference.py | 3 ++- .../transforms_tests/test_collapse_tuple.py | 4 +--- .../transforms_tests/test_constant_folding.py | 2 +- .../iterator_tests/transforms_tests/test_cse.py | 3 ++- .../transforms_tests/test_global_tmps.py | 3 ++- .../transforms_tests/test_inline_lambdas.py | 2 +- .../transforms_tests/test_inline_lifts.py | 2 +- .../transforms_tests/test_propagate_deref.py | 2 +- .../transforms_tests/test_trace_shifts.py | 3 ++- 22 files changed, 49 insertions(+), 22 deletions(-) create mode 100644 src/gt4py/next/iterator/ir_utils/__init__.py rename src/gt4py/next/iterator/{transforms => ir_utils}/common_pattern_matcher.py (100%) rename src/gt4py/next/iterator/{ => ir_utils}/ir_makers.py (100%) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 67272f88b8..e06c651b13 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -53,7 +53,12 @@ from gt4py.next.ffront.past_to_itir import ProgramLowering from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_makers import literal_from_value, promote_to_const_iterator, ref, sym +from gt4py.next.iterator.ir_utils.ir_makers import ( + literal_from_value, + promote_to_const_iterator, + ref, + sym, +) from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.program_processors.runners import roundtrip from gt4py.next.type_system import type_info, type_specifications as ts, type_translation diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 816b8581f1..3030c03fd1 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -25,7 +25,8 @@ ) from gt4py.next.ffront.fbuiltins import FUN_BUILTIN_NAMES, MATH_BUILTIN_NAMES, TYPE_BUILTIN_NAMES from gt4py.next.ffront.foast_introspection import StmtReturnKind, deduce_stmt_return_kind -from gt4py.next.iterator import ir as itir, ir_makers as im +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_info, type_specifications as ts diff --git a/src/gt4py/next/iterator/ir_utils/__init__.py b/src/gt4py/next/iterator/ir_utils/__init__.py new file mode 100644 index 0000000000..6c43e2f12a --- /dev/null +++ b/src/gt4py/next/iterator/ir_utils/__init__.py @@ -0,0 +1,13 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/src/gt4py/next/iterator/transforms/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py similarity index 100% rename from src/gt4py/next/iterator/transforms/common_pattern_matcher.py rename to src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py diff --git a/src/gt4py/next/iterator/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py similarity index 100% rename from src/gt4py/next/iterator/ir_makers.py rename to src/gt4py/next/iterator/ir_utils/ir_makers.py diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index a541e985ad..2b1c8169fb 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -16,7 +16,8 @@ from lark import lark, lexer as lark_lexer, visitors as lark_visitors -from gt4py.next.iterator import ir, ir_makers as im +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im GRAMMAR = """ diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index fbe6a2ae82..d1f6bba8d6 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -20,7 +20,7 @@ from gt4py._core import definitions as core_defs from gt4py.eve import Node from gt4py.next import common, iterator -from gt4py.next.iterator import builtins, ir_makers as im +from gt4py.next.iterator import builtins from gt4py.next.iterator.ir import ( AxisLiteral, Expr, @@ -34,6 +34,7 @@ Sym, SymRef, ) +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_info, type_specifications, type_translation diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index cda422f30d..fa326760b0 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -13,7 +13,8 @@ # SPDX-License-Identifier: GPL-3.0-or-later from gt4py.eve import NodeTranslator -from gt4py.next.iterator import embedded, ir, ir_makers as im +from gt4py.next.iterator import embedded, ir +from gt4py.next.iterator.ir_utils import ir_makers as im class ConstantFolding(NodeTranslator): diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index e1b697e0bc..d9d3d18213 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -22,10 +22,11 @@ from gt4py.eve import Coerced, NodeTranslator from gt4py.eve.traits import SymbolTableTrait from gt4py.eve.utils import UIDGenerator -from gt4py.next.iterator import ir, ir_makers as im, type_inference +from gt4py.next.iterator import ir, type_inference +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.pretty_printer import PrettyPrinter from gt4py.next.iterator.transforms import trace_shifts -from gt4py.next.iterator.transforms.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.transforms.cse import extract_subexpression from gt4py.next.iterator.transforms.eta_reduction import EtaReduction from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index fc268f85e3..eac4338345 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -17,7 +17,7 @@ from gt4py.eve import NodeTranslator from gt4py.next.iterator import ir -from gt4py.next.iterator.transforms.common_pattern_matcher import is_applied_lift +from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols from gt4py.next.iterator.transforms.symbol_ref_utils import CountSymbolRefs diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index 8d62450e67..d7d8e5e612 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -19,7 +19,8 @@ import gt4py.eve as eve from gt4py.eve import NodeTranslator, traits -from gt4py.next.iterator import ir, ir_makers as im +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index e3084eaba5..60a5db7e96 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -20,7 +20,7 @@ from gt4py.eve.utils import UIDGenerator from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms.common_pattern_matcher import is_applied_lift +from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift def _is_shifted(arg: itir.Expr) -> TypeGuard[itir.FunCall]: diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py index dd66beb522..2dd4b91c48 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py @@ -29,7 +29,8 @@ from gt4py.next.ffront.ast_passes import single_static_assign as ssa from gt4py.next.ffront.foast_to_itir import FieldOperatorLowering from gt4py.next.ffront.func_to_foast import FieldOperatorParser -from gt4py.next.iterator import ir as itir, ir_makers as im +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_specifications as ts, type_translation diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 1526e97d74..cacdb7b070 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -15,7 +15,8 @@ import numpy as np import gt4py.next as gtx -from gt4py.next.iterator import ir, ir_makers as im, type_inference as ti +from gt4py.next.iterator import ir, type_inference as ti +from gt4py.next.iterator.ir_utils import ir_makers as im def test_unsatisfiable_constraints(): diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 736bf04d64..1444b0a64f 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -12,9 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import pytest - -from gt4py.next.iterator import ir, ir_makers as im +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index 5d052b1989..275412a537 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.next.iterator import ir_makers as im +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.constant_folding import ConstantFolding diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index 5d9e0933a7..065095e1c2 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -14,7 +14,8 @@ import textwrap from gt4py.eve.utils import UIDGenerator -from gt4py.next.iterator import ir, ir_makers as im +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.cse import ( CommonSubexpressionElimination as CSE, extract_subexpression, diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index 88f6ed517b..86c3c98c62 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -15,7 +15,8 @@ import gt4py.next as gtx from gt4py.eve.utils import UIDs -from gt4py.next.iterator import ir, ir_makers as im +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.global_tmps import ( AUTO_DOMAIN, FencilWithTemporaries, diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index b9f2ca16a1..88e554f349 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -14,7 +14,7 @@ import pytest -from gt4py.next.iterator import ir_makers as im +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py index 1da2b8a044..e1d440044d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py @@ -14,7 +14,7 @@ import pytest -from gt4py.next.iterator import ir_makers as im +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.inline_lifts import InlineLifts diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py index ffbf2c2c8e..e2e29cd4db 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.next.iterator import ir_makers as im +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.propagate_deref import PropagateDeref diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py index 2624a17ebd..47db632a5e 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py @@ -12,7 +12,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.next.iterator import ir, ir_makers as im +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.trace_shifts import Sentinel, TraceShifts From 36039a4ada9410d848f62d538688a77f7f90ba51 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 28 Nov 2023 15:43:08 +0100 Subject: [PATCH 08/21] Fix format --- src/gt4py/next/ffront/decorator.py | 7 +- src/gt4py/next/iterator/ir_utils/__init__.py | 13 +++ .../ir_utils/common_pattern_matcher.py | 2 +- .../ir_utils/is_equal_value_heuristics.py | 19 ++- src/gt4py/next/iterator/pretty_parser.py | 1 + src/gt4py/next/iterator/tracing.py | 2 +- .../iterator/transforms/collapse_tuple.py | 108 +++++++++++------- .../next/iterator/transforms/global_tmps.py | 2 +- .../next/iterator/transforms/pass_manager.py | 2 +- .../iterator/transforms/propagate_deref.py | 6 +- .../next/program_processors/runners/gtfn.py | 5 +- tests/next_tests/exclusion_matrices.py | 2 +- .../transforms_tests/test_collapse_tuple.py | 68 ++++------- 13 files changed, 141 insertions(+), 96 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 9494ed4c7b..e06c651b13 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -53,7 +53,12 @@ from gt4py.next.ffront.past_to_itir import ProgramLowering from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils.ir_makers import literal_from_value, promote_to_const_iterator, ref, sym +from gt4py.next.iterator.ir_utils.ir_makers import ( + literal_from_value, + promote_to_const_iterator, + ref, + sym, +) from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.program_processors.runners import roundtrip from gt4py.next.type_system import type_info, type_specifications as ts, type_translation diff --git a/src/gt4py/next/iterator/ir_utils/__init__.py b/src/gt4py/next/iterator/ir_utils/__init__.py index e69de29bb2..6c43e2f12a 100644 --- a/src/gt4py/next/iterator/ir_utils/__init__.py +++ b/src/gt4py/next/iterator/ir_utils/__init__.py @@ -0,0 +1,13 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 66a35e32af..661086aa51 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -34,4 +34,4 @@ def is_let(node: itir.Node) -> bool: def is_if_call(node: itir.Expr): """Match expression of the form `if_(cond, true_branch, false_branch)`""" - return isinstance(node, itir.FunCall) and node.fun == im.ref("if_") \ No newline at end of file + return isinstance(node, itir.FunCall) and node.fun == im.ref("if_") diff --git a/src/gt4py/next/iterator/ir_utils/is_equal_value_heuristics.py b/src/gt4py/next/iterator/ir_utils/is_equal_value_heuristics.py index 6b153edebc..4e4893721c 100644 --- a/src/gt4py/next/iterator/ir_utils/is_equal_value_heuristics.py +++ b/src/gt4py/next/iterator/ir_utils/is_equal_value_heuristics.py @@ -1,3 +1,17 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + import dataclasses from collections import ChainMap @@ -23,6 +37,7 @@ class CannonicalizeBoundSymbolNames(eve.NodeTranslator): >>> cannonicalized_testee2 = CannonicalizeBoundSymbolNames.apply(testee2) >>> assert cannonicalized_testee1 == cannonicalized_testee2 """ + _uids: eve_utils.UIDGenerator = dataclasses.field( init=False, repr=False, default_factory=lambda: eve_utils.UIDGenerator(prefix="_csym") ) @@ -50,4 +65,6 @@ def is_equal_value_heuristics(a: itir.Expr, b: itir.Expr): >>> testee2 = im.lambda_("c")(im.plus("c", "b")) >>> assert is_equal_value_heuristics(testee1, testee2) """ - return a == b or (CannonicalizeBoundSymbolNames.apply(a) == CannonicalizeBoundSymbolNames.apply(b)) \ No newline at end of file + return a == b or ( + CannonicalizeBoundSymbolNames.apply(a) == CannonicalizeBoundSymbolNames.apply(b) + ) diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index 78f6fc9e45..2b1c8169fb 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -19,6 +19,7 @@ from gt4py.next.iterator import ir from gt4py.next.iterator.ir_utils import ir_makers as im + GRAMMAR = """ start: fencil_definition | function_definition diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index 30d3c3225f..d1f6bba8d6 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -21,7 +21,6 @@ from gt4py.eve import Node from gt4py.next import common, iterator from gt4py.next.iterator import builtins -from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.ir import ( AxisLiteral, Expr, @@ -35,6 +34,7 @@ Sym, SymRef, ) +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_info, type_specifications, type_translation diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 98da52ee91..0800ae708c 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -11,18 +11,17 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later +import dataclasses import enum from typing import Optional -import dataclasses - -from gt4py import eve import gt4py.eve.utils +from gt4py import eve from gt4py.next import type_inference from gt4py.next.iterator import ir, type_inference as it_type_inference from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_let, is_if_call -from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda, InlineLambdas +from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_if_call, is_let +from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas, inline_lambda class UnknownLength: @@ -53,22 +52,25 @@ def _get_tuple_size(elem: ir.Node, use_global_information: bool) -> int | type[U def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr): """Given a itir.FunCall return a new call with one of its argument replaced.""" return ir.FunCall( - fun=node.fun, - args=[arg if i != arg_idx else new_arg for i, arg in enumerate(node.args)] + fun=node.fun, args=[arg if i != arg_idx else new_arg for i, arg in enumerate(node.args)] ) + def _is_trivial_make_tuple_call(node: ir.Expr): if not (isinstance(node, ir.FunCall) and node.fun == im.ref("make_tuple")): return False - if not all(isinstance(arg, (ir.SymRef, ir.Literal)) or _is_trivial_make_tuple_call(arg) for arg in node.args): + if not all( + isinstance(arg, (ir.SymRef, ir.Literal)) or _is_trivial_make_tuple_call(arg) + for arg in node.args + ): return False return True + def nlet(bindings: list[tuple[ir.Sym, str], ir.Expr]): return im.let(*[el for tup in bindings for el in tup]) - @dataclasses.dataclass(frozen=True) class CollapseTuple(eve.NodeTranslator): """ @@ -92,20 +94,22 @@ class Flag(enum.IntEnum): #: TODO PROPAGATE_TO_IF_ON_TUPLES = 32 #: TODO - PROPAGATE_NESTED_LET=64 + PROPAGATE_NESTED_LET = 64 #: TODO - INLINE_TRIVIAL_LET=128 + INLINE_TRIVIAL_LET = 128 ignore_tuple_size: bool use_global_type_inference: bool - flags: int = (Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET - | Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE - | Flag.PROPAGATE_TUPLE_GET - | Flag.LETIFY_MAKE_TUPLE_ELEMENTS - | Flag.INLINE_TRIVIAL_MAKE_TUPLE - | Flag.PROPAGATE_TO_IF_ON_TUPLES - | Flag.PROPAGATE_NESTED_LET - | Flag.INLINE_TRIVIAL_LET) + flags: int = ( + Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET + | Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE + | Flag.PROPAGATE_TUPLE_GET + | Flag.LETIFY_MAKE_TUPLE_ELEMENTS + | Flag.INLINE_TRIVIAL_MAKE_TUPLE + | Flag.PROPAGATE_TO_IF_ON_TUPLES + | Flag.PROPAGATE_NESTED_LET + | Flag.INLINE_TRIVIAL_LET + ) PRESERVED_ANNEX_ATTRS = ("type",) @@ -125,7 +129,7 @@ def apply( ignore_tuple_size: bool = False, use_global_type_inference: bool = False, # manually passing flags is mostly for allowing separate testing of the modes - flags = None + flags=None, ) -> ir.Node: """ Simplifies `make_tuple`, `tuple_get` calls. @@ -140,7 +144,7 @@ def apply( new_node = cls( ignore_tuple_size=ignore_tuple_size, use_global_type_inference=use_global_type_inference, - flags=flags + flags=flags, ).visit(node) # inline to remove left-overs from LETIFY_MAKE_TUPLE_ELEMENTS. this is important @@ -148,8 +152,9 @@ def apply( # and the CSE pass can not remove them. # TODO: test case for `scan(lambda carry: {1, 2})` (see solve_nonhydro_stencil_52_like_z_q_tup) if flags & cls.Flag.LETIFY_MAKE_TUPLE_ELEMENTS: - new_node = InlineLambdas.apply(new_node, opcount_preserving=True, - force_inline_lambda_args=False) + new_node = InlineLambdas.apply( + new_node, opcount_preserving=True, force_inline_lambda_args=False + ) return new_node @@ -171,13 +176,15 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: for i, v in enumerate(node.args): assert isinstance(v, ir.FunCall) assert isinstance(v.args[0], ir.Literal) - if not (int(v.args[0].value) == i and _is_equal_value_heuristics(v.args[1], first_expr)): + if not ( + int(v.args[0].value) == i and _is_equal_value_heuristics(v.args[1], first_expr) + ): # tuple argument differs, just continue with the rest of the tree return self.generic_visit(node) - if self.ignore_tuple_size or _get_tuple_size(first_expr, self.use_global_type_inference) == len( - node.args - ): + if self.ignore_tuple_size or _get_tuple_size( + first_expr, self.use_global_type_inference + ) == len(node.args): return first_expr if ( @@ -199,32 +206,38 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: if ( self.flags & self.Flag.PROPAGATE_TUPLE_GET and node.fun == ir.SymRef(id="tuple_get") - and isinstance(node.args[0], ir.Literal) # TODO: extend to general symbols as long as the tail call in the let does not capture + and isinstance( + node.args[0], ir.Literal + ) # TODO: extend to general symbols as long as the tail call in the let does not capture ): # `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))` if is_let(node.args[1]): idx, let_expr = node.args return self.visit( - im.call(im.lambda_(*let_expr.fun.params)(im.tuple_get(idx, let_expr.fun.expr)))(*let_expr.args) + im.call(im.lambda_(*let_expr.fun.params)(im.tuple_get(idx, let_expr.fun.expr)))( + *let_expr.args + ) ) elif isinstance(node.args[1], ir.FunCall) and node.args[1].fun == im.ref("if_"): idx = node.args[0] cond, true_branch, false_branch = node.args[1].args return self.visit( im.if_(cond, im.tuple_get(idx, true_branch), im.tuple_get(idx, false_branch)) - ) # todo: check if visit needed + ) # todo: check if visit needed - if ( - self.flags & self.Flag.LETIFY_MAKE_TUPLE_ELEMENTS - and node.fun == ir.SymRef(id="make_tuple") + if self.flags & self.Flag.LETIFY_MAKE_TUPLE_ELEMENTS and node.fun == ir.SymRef( + id="make_tuple" ): # `make_tuple(expr1, expr1)` # -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))` bound_vars: dict[str, ir.Expr] = {} new_args: list[ir.Expr] = [] for i, arg in enumerate(node.args): - if isinstance(node, ir.FunCall) and node.fun == im.ref( - "make_tuple") and not _is_trivial_make_tuple_call(node): + if ( + isinstance(node, ir.FunCall) + and node.fun == im.ref("make_tuple") + and not _is_trivial_make_tuple_call(node) + ): el_name = self._letify_make_tuple_uids.sequential_id() new_args.append(im.ref(el_name)) bound_vars[el_name] = arg @@ -232,8 +245,11 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: new_args.append(arg) if bound_vars: - return self.visit(im.let(*(el for item in bound_vars.items() for el in item))( - im.call(node.fun)(*new_args))) + return self.visit( + im.let(*(el for item in bound_vars.items() for el in item))( + im.call(node.fun)(*new_args) + ) + ) if self.flags & self.Flag.INLINE_TRIVIAL_MAKE_TUPLE and is_let(node): # `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))` @@ -250,7 +266,9 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: if is_if_call(arg): cond, true_branch, false_branch = arg.args new_true_branch = self.visit(_with_altered_arg(node, i, true_branch), **kwargs) - new_false_branch = self.visit(_with_altered_arg(node, i, false_branch), **kwargs) + new_false_branch = self.visit( + _with_altered_arg(node, i, false_branch), **kwargs + ) return im.if_(cond, new_true_branch, new_false_branch) if self.flags & self.Flag.PROPAGATE_NESTED_LET and is_let(node): @@ -268,12 +286,20 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: else: inner_vars[arg_sym] = arg if outer_vars: - node = self.visit(nlet(tuple(outer_vars.items()))(nlet(tuple(inner_vars.items()))(original_inner_expr))) + node = self.visit( + nlet(tuple(outer_vars.items()))( + nlet(tuple(inner_vars.items()))(original_inner_expr) + ) + ) - if self.flags & self.Flag.INLINE_TRIVIAL_LET and is_let(node) and isinstance(node.fun.expr, ir.SymRef): + if ( + self.flags & self.Flag.INLINE_TRIVIAL_LET + and is_let(node) + and isinstance(node.fun.expr, ir.SymRef) + ): # `let(a, 1)(a)` -> `1` for arg_sym, arg in zip(node.fun.params, node.args): if node.fun.expr == im.ref(arg_sym.id): return arg - return node \ No newline at end of file + return node diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 4faafa96f5..d9d3d18213 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -24,9 +24,9 @@ from gt4py.eve.utils import UIDGenerator from gt4py.next.iterator import ir, type_inference from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.pretty_printer import PrettyPrinter from gt4py.next.iterator.transforms import trace_shifts -from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.transforms.cse import extract_subexpression from gt4py.next.iterator.transforms.eta_reduction import EtaReduction from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 07cd1aadae..d7578a62e0 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -111,7 +111,7 @@ def apply_common_transforms( # to limit number of times global type inference is executed, only in the last iterations. use_global_type_inference=inlined == ir, ) - inlined = PropagateDeref.apply(inlined) # todo: document + inlined = PropagateDeref.apply(inlined) # todo: document if inlined == ir: break diff --git a/src/gt4py/next/iterator/transforms/propagate_deref.py b/src/gt4py/next/iterator/transforms/propagate_deref.py index b01c49b4d9..fe0578f0cf 100644 --- a/src/gt4py/next/iterator/transforms/propagate_deref.py +++ b/src/gt4py/next/iterator/transforms/propagate_deref.py @@ -57,7 +57,11 @@ def visit_FunCall(self, node: ir.FunCall): ), args=lambda_args, ) - elif node.fun == im.ref("deref") and isinstance(node.args[0], ir.FunCall) and node.args[0].fun == im.ref("if_"): + elif ( + node.fun == im.ref("deref") + and isinstance(node.args[0], ir.FunCall) + and node.args[0].fun == im.ref("if_") + ): cond, true_branch, false_branch = node.args[0].args return im.if_(cond, im.deref(true_branch), im.deref(false_branch)) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 0ec42d9910..eb5fb77033 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -26,7 +26,7 @@ from gt4py.next.otf import languages, recipes, stages, step_types, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import cache, compiler -from gt4py.next.otf.compilation.build_systems import compiledb, cmake +from gt4py.next.otf.compilation.build_systems import cmake, compiledb from gt4py.next.program_processors import otf_compile_executor from gt4py.next.program_processors.codegens.gtfn import gtfn_module from gt4py.next.type_system.type_translation import from_value @@ -130,8 +130,7 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int: ) GTFN_DEFAULT_COMPILE_STEP: step_types.CompilationStep = compiler.Compiler( - cache_strategy=cache.Strategy.PERSISTENT, - builder_factory=cmake.CMakeFactory() + cache_strategy=cache.Strategy.PERSISTENT, builder_factory=cmake.CMakeFactory() ) diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index fafa7b53bb..fedaf7c666 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -115,7 +115,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): COMMON_SKIP_TEST_LIST = [ (REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), (USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE), - #(USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), + # (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index e42c08194d..c6a92166e9 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -62,7 +62,9 @@ def test_incompatible_size_make_tuple_tuple_get(): def test_merged_with_smaller_outer_size_make_tuple_tuple_get(): testee = im.make_tuple(im.tuple_get(0, im.make_tuple("first", "second"))) - actual = CollapseTuple.apply(testee, ignore_tuple_size=True, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET) + actual = CollapseTuple.apply( + testee, ignore_tuple_size=True, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET + ) assert actual == im.make_tuple("first", "second") @@ -72,15 +74,14 @@ def test_simple_tuple_get_make_tuple(): actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE) assert expected == actual + def test_propagate_tuple_get(): expected = im.let("el1", 1, "el2", 2)(im.tuple_get(0, im.make_tuple("el1", "el2"))) testee = im.tuple_get(0, im.let("el1", 1, "el2", 2)(im.make_tuple("el1", "el2"))) - actual = CollapseTuple.apply( - testee, - flags=CollapseTuple.Flag.PROPAGATE_TUPLE_GET - ) + actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.PROPAGATE_TUPLE_GET) assert expected == actual + def test_letify_make_tuple_elements(): opaque_call = im.call("opaque")() testee = im.make_tuple(opaque_call, opaque_call) @@ -88,49 +89,32 @@ def test_letify_make_tuple_elements(): im.make_tuple("_tuple_el_1", "_tuple_el_2") ) - actual = CollapseTuple.apply( - testee, - flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS - ) + actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS) assert actual == expected def test_letify_make_tuple_with_trivial_elements(): - testee = im.let("a", 1, "b", 2)( - im.make_tuple("a", "b") - ) + testee = im.let("a", 1, "b", 2)(im.make_tuple("a", "b")) expected = testee # did nothing - actual = CollapseTuple.apply( - testee, - flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS - ) + actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS) assert actual == expected + def test_inline_trivial_make_tuple(): testee = im.let("tup", im.make_tuple("a", "b"))("tup") expected = im.make_tuple("a", "b") - actual = CollapseTuple.apply( - testee, - flags=CollapseTuple.Flag.INLINE_TRIVIAL_MAKE_TUPLE - ) + actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.INLINE_TRIVIAL_MAKE_TUPLE) assert actual == expected + def test_propagate_to_if_on_tuples(): - testee = im.tuple_get(0, - im.if_("cond", - im.make_tuple(1, 2), - im.make_tuple(3, 4))) - expected = im.if_("cond", - im.tuple_get(0, - im.make_tuple(1, 2)), - im.tuple_get(0, - im.make_tuple(3, 4))) - actual = CollapseTuple.apply( - testee, - flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES + testee = im.tuple_get(0, im.if_("cond", im.make_tuple(1, 2), im.make_tuple(3, 4))) + expected = im.if_( + "cond", im.tuple_get(0, im.make_tuple(1, 2)), im.tuple_get(0, im.make_tuple(3, 4)) ) + actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES) assert actual == expected @@ -138,14 +122,13 @@ def test_propagate_to_if_on_tuples_with_let(): testee = im.let("val", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4)))( im.tuple_get(0, "val") ) - expected = im.if_("cond", - im.tuple_get(0, - im.make_tuple(1, 2)), - im.tuple_get(0, - im.make_tuple(3, 4))) + expected = im.if_( + "cond", im.tuple_get(0, im.make_tuple(1, 2)), im.tuple_get(0, im.make_tuple(3, 4)) + ) actual = CollapseTuple.apply( testee, - flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES | CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS + flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES + | CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, ) assert actual == expected @@ -153,10 +136,7 @@ def test_propagate_to_if_on_tuples_with_let(): def test_propagate_nested_lift(): testee = im.let("a", im.let("b", 1)("a_val"))("a") expected = im.let("b", 1)(im.let("a", "a_val")("a")) - actual = CollapseTuple.apply( - testee, - flags=CollapseTuple.Flag.PROPAGATE_NESTED_LET - ) + actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.PROPAGATE_NESTED_LET) assert actual == expected @@ -168,6 +148,6 @@ def test_collapse_complicated_(): expected = im.if_("cond", 1, 3) actual = CollapseTuple.apply( testee, - #flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES + # flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES ) - assert actual == expected \ No newline at end of file + assert actual == expected From 8772906eb1976c05247e452e9ba967dbc7acbf56 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 29 Nov 2023 16:34:10 +0100 Subject: [PATCH 09/21] Fix tests --- src/gt4py/next/iterator/transforms/cse.py | 2 +- src/gt4py/next/iterator/transforms/symbol_ref_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 672e23c5e7..cc70e11413 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -233,7 +233,7 @@ def extract_subexpression( Examples: Default case for `(x+y) + ((x+y)+z)`: - >>> import gt4py.next.iterator.ir_makers as im + >>> import gt4py.next.iterator.ir_utils.ir_makers as im >>> from gt4py.eve.utils import UIDGenerator >>> expr = im.plus(im.plus("x", "y"), im.plus(im.plus("x", "y"), "z")) >>> predicate = lambda subexpr, num_occurences: num_occurences > 1 diff --git a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py index 1f604d62b9..1c587fb9d6 100644 --- a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py +++ b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py @@ -36,7 +36,7 @@ def apply( Count references to given or all symbols in scope. Examples: - >>> import gt4py.next.iterator.ir_makers as im + >>> import gt4py.next.iterator.ir_utils.ir_makers as im >>> expr = im.plus(im.plus("x", "y"), im.plus(im.plus("x", "y"), "z")) >>> CountSymbolRefs.apply(expr) {'x': 2, 'y': 2, 'z': 1} From ace2dc06bedf933232eb87ae2550bd642b28372b Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 29 Nov 2023 16:49:23 +0100 Subject: [PATCH 10/21] Fix tests --- .../{is_equal_value_heuristics.py => misc.py} | 0 .../next/iterator/transforms/collapse_tuple.py | 18 +++++++++++------- .../transforms_tests/test_collapse_tuple.py | 5 +++-- 3 files changed, 14 insertions(+), 9 deletions(-) rename src/gt4py/next/iterator/ir_utils/{is_equal_value_heuristics.py => misc.py} (100%) diff --git a/src/gt4py/next/iterator/ir_utils/is_equal_value_heuristics.py b/src/gt4py/next/iterator/ir_utils/misc.py similarity index 100% rename from src/gt4py/next/iterator/ir_utils/is_equal_value_heuristics.py rename to src/gt4py/next/iterator/ir_utils/misc.py diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 0800ae708c..338c78c16d 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -19,7 +19,7 @@ from gt4py import eve from gt4py.next import type_inference from gt4py.next.iterator import ir, type_inference as it_type_inference -from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils import ir_makers as im, misc as ir_misc from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_if_call, is_let from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas, inline_lambda @@ -90,13 +90,15 @@ class Flag(enum.IntEnum): #: `{1, 2}` -> `(λ(_tuple_el_1, _tuple_el_2) → {_tuple_el_1, _tuple_el_2})(1, 2)` LETIFY_MAKE_TUPLE_ELEMENTS = 8 #: TODO - INLINE_TRIVIAL_MAKE_TUPLE = 16 + REMOVE_LETIFIED_MAKE_TUPLE_ELEMENTS = 16 #: TODO - PROPAGATE_TO_IF_ON_TUPLES = 32 + INLINE_TRIVIAL_MAKE_TUPLE = 32 #: TODO - PROPAGATE_NESTED_LET = 64 + PROPAGATE_TO_IF_ON_TUPLES = 64 #: TODO - INLINE_TRIVIAL_LET = 128 + PROPAGATE_NESTED_LET = 128 + #: TODO + INLINE_TRIVIAL_LET = 256 ignore_tuple_size: bool use_global_type_inference: bool @@ -105,6 +107,7 @@ class Flag(enum.IntEnum): | Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE | Flag.PROPAGATE_TUPLE_GET | Flag.LETIFY_MAKE_TUPLE_ELEMENTS + | Flag.REMOVE_LETIFIED_MAKE_TUPLE_ELEMENTS | Flag.INLINE_TRIVIAL_MAKE_TUPLE | Flag.PROPAGATE_TO_IF_ON_TUPLES | Flag.PROPAGATE_NESTED_LET @@ -151,7 +154,7 @@ def apply( # as otherwise two equal expressions containing a tuple will not be equal anymore # and the CSE pass can not remove them. # TODO: test case for `scan(lambda carry: {1, 2})` (see solve_nonhydro_stencil_52_like_z_q_tup) - if flags & cls.Flag.LETIFY_MAKE_TUPLE_ELEMENTS: + if flags & cls.Flag.REMOVE_LETIFIED_MAKE_TUPLE_ELEMENTS: new_node = InlineLambdas.apply( new_node, opcount_preserving=True, force_inline_lambda_args=False ) @@ -177,7 +180,8 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: assert isinstance(v, ir.FunCall) assert isinstance(v.args[0], ir.Literal) if not ( - int(v.args[0].value) == i and _is_equal_value_heuristics(v.args[1], first_expr) + int(v.args[0].value) == i + and ir_misc.is_equal_value_heuristics(v.args[1], first_expr) ): # tuple argument differs, just continue with the rest of the tree return self.generic_visit(node) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index c6a92166e9..c312b28906 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -119,7 +119,7 @@ def test_propagate_to_if_on_tuples(): def test_propagate_to_if_on_tuples_with_let(): - testee = im.let("val", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4)))( + testee = im.let("val", im.if_("cond", im.make_tuple(1, 2), im.make_tuple(3, 4)))( im.tuple_get(0, "val") ) expected = im.if_( @@ -128,7 +128,8 @@ def test_propagate_to_if_on_tuples_with_let(): actual = CollapseTuple.apply( testee, flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES - | CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, + | CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS + | CollapseTuple.Flag.REMOVE_LETIFIED_MAKE_TUPLE_ELEMENTS, ) assert actual == expected From 8b7a6d7e24ebdd6cbaca36709adde6d73351dc01 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 29 Nov 2023 17:59:08 +0100 Subject: [PATCH 11/21] Fix tests --- .../ir_utils/common_pattern_matcher.py | 4 +-- src/gt4py/next/iterator/ir_utils/ir_makers.py | 32 +++++++++++++++---- .../iterator/transforms/collapse_tuple.py | 22 ++++--------- .../transforms_tests/test_collapse_tuple.py | 8 ++--- .../transforms_tests/test_cse.py | 17 +++++----- 5 files changed, 45 insertions(+), 38 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 661086aa51..1b1cc494d3 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -28,10 +28,10 @@ def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]: def is_let(node: itir.Node) -> bool: - """Match expression of the form `(λ(...) → ...)(...)`""" + """Match expression of the form `(λ(...) → ...)(...)`.""" return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda) def is_if_call(node: itir.Expr): - """Match expression of the form `if_(cond, true_branch, false_branch)`""" + """Match expression of the form `if_(cond, true_branch, false_branch)`.""" return isinstance(node, itir.FunCall) and node.fun == im.ref("if_") diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index f7086ada0c..507aa4ac27 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -12,7 +12,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Callable, Union +import typing +from typing import Callable, Iterable, Union from gt4py._core import definitions as core_defs from gt4py.next.iterator import ir as itir @@ -242,16 +243,33 @@ class let: -------- >>> str(let("a", "b")("a")) # doctest: +ELLIPSIS '(λ(a) → a)(b)' - >>> str(let("a", 1, - ... "b", 2 + >>> str(let(("a", 1), + ... ("b", 2) ... )(plus("a", "b"))) '(λ(a, b) → a + b)(1, 2)' """ - def __init__(self, *vars_and_values): - assert len(vars_and_values) % 2 == 0 - self.vars = vars_and_values[0::2] - self.init_forms = vars_and_values[1::2] + @typing.overload + def __init__(self, var: str | itir.Sym, init_form: itir.Expr): + ... + + @typing.overload + def __init__(self, *args: Iterable[tuple[str | itir.Sym, itir.Expr]]): + ... + + def __init__(self, *args): + if all(isinstance(arg, tuple) and len(arg) == 2 for arg in args): + assert isinstance(args, tuple) + assert all(isinstance(arg, tuple) and len(arg) == 2 for arg in args) + self.vars = [var for var, _ in args] + self.init_forms = [init_form for _, init_form in args] + elif len(args) == 2: + self.vars = [args[0]] + self.init_forms = [args[1]] + else: + raise TypeError( + "Invalid arguments. Expected a variable name and an init form or a list thereof." + ) def __call__(self, form): return call(lambda_(*self.vars)(form))(*self.init_forms) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 338c78c16d..9733f87009 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -15,8 +15,8 @@ import enum from typing import Optional -import gt4py.eve.utils from gt4py import eve +from gt4py.eve import utils as eve_utils from gt4py.next import type_inference from gt4py.next.iterator import ir, type_inference as it_type_inference from gt4py.next.iterator.ir_utils import ir_makers as im, misc as ir_misc @@ -67,10 +67,6 @@ def _is_trivial_make_tuple_call(node: ir.Expr): return True -def nlet(bindings: list[tuple[ir.Sym, str], ir.Expr]): - return im.let(*[el for tup in bindings for el in tup]) - - @dataclasses.dataclass(frozen=True) class CollapseTuple(eve.NodeTranslator): """ @@ -118,8 +114,8 @@ class Flag(enum.IntEnum): # we use one UID generator per instance such that the generated ids are # stable across multiple runs (required for caching to properly work) - _letify_make_tuple_uids: eve.utils.UIDGenerator = dataclasses.field( - init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="_tuple_el") + _letify_make_tuple_uids: eve_utils.UIDGenerator = dataclasses.field( + init=False, repr=False, default_factory=lambda: eve_utils.UIDGenerator(prefix="_tuple_el") ) _node_types: Optional[dict[int, type_inference.Type]] = None @@ -236,7 +232,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: # -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))` bound_vars: dict[str, ir.Expr] = {} new_args: list[ir.Expr] = [] - for i, arg in enumerate(node.args): + for arg in node.args: if ( isinstance(node, ir.FunCall) and node.fun == im.ref("make_tuple") @@ -249,11 +245,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: new_args.append(arg) if bound_vars: - return self.visit( - im.let(*(el for item in bound_vars.items() for el in item))( - im.call(node.fun)(*new_args) - ) - ) + return self.visit(im.let(*bound_vars.items())(im.call(node.fun)(*new_args))) if self.flags & self.Flag.INLINE_TRIVIAL_MAKE_TUPLE and is_let(node): # `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))` @@ -291,9 +283,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: inner_vars[arg_sym] = arg if outer_vars: node = self.visit( - nlet(tuple(outer_vars.items()))( - nlet(tuple(inner_vars.items()))(original_inner_expr) - ) + im.let(*outer_vars.items())(im.let(*inner_vars.items())(original_inner_expr)) ) if ( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index c312b28906..dd4406c86d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -76,8 +76,8 @@ def test_simple_tuple_get_make_tuple(): def test_propagate_tuple_get(): - expected = im.let("el1", 1, "el2", 2)(im.tuple_get(0, im.make_tuple("el1", "el2"))) - testee = im.tuple_get(0, im.let("el1", 1, "el2", 2)(im.make_tuple("el1", "el2"))) + expected = im.let(("el1", 1), ("el2", 2))(im.tuple_get(0, im.make_tuple("el1", "el2"))) + testee = im.tuple_get(0, im.let(("el1", 1), ("el2", 2))(im.make_tuple("el1", "el2"))) actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.PROPAGATE_TUPLE_GET) assert expected == actual @@ -85,7 +85,7 @@ def test_propagate_tuple_get(): def test_letify_make_tuple_elements(): opaque_call = im.call("opaque")() testee = im.make_tuple(opaque_call, opaque_call) - expected = im.let("_tuple_el_1", opaque_call, "_tuple_el_2", opaque_call)( + expected = im.let(("_tuple_el_1", opaque_call), ("_tuple_el_2", opaque_call))( im.make_tuple("_tuple_el_1", "_tuple_el_2") ) @@ -94,7 +94,7 @@ def test_letify_make_tuple_elements(): def test_letify_make_tuple_with_trivial_elements(): - testee = im.let("a", 1, "b", 2)(im.make_tuple("a", "b")) + testee = im.let(("a", 1), ("b", 2))(im.make_tuple("a", "b")) expected = testee # did nothing actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index 065095e1c2..fb7720f4d7 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -213,15 +213,14 @@ def is_let(node: ir.Expr): testee = im.plus( im.let( - "c", - im.let( - "a", - 1, - "b", - 2, - )(im.plus("a", "b")), - "d", - 3, + ( + "c", + im.let( + ("a", 1), + ("b", 2), + )(im.plus("a", "b")), + ), + ("d", 3), )(im.plus("c", "d")), 4, ) From 61133120d7c892d85523f8b89aa76bfa47faf42f Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 29 Nov 2023 18:02:57 +0100 Subject: [PATCH 12/21] Fix tests --- src/gt4py/next/iterator/ir_utils/misc.py | 4 ++-- src/gt4py/next/iterator/transforms/collapse_tuple.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py index 4e4893721c..fcbd7962b8 100644 --- a/src/gt4py/next/iterator/ir_utils/misc.py +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -57,13 +57,13 @@ def visit_SymRef(self, node: itir.SymRef, *, sym_map: dict[str, str]): return im.ref(sym_map[node.id]) if node.id in sym_map else node -def is_equal_value_heuristics(a: itir.Expr, b: itir.Expr): +def is_provable_equal(a: itir.Expr, b: itir.Expr): """ Return true if, bot not only if, two expression (with equal scope) have the same value. >>> testee1 = im.lambda_("a")(im.plus("a", "b")) >>> testee2 = im.lambda_("c")(im.plus("c", "b")) - >>> assert is_equal_value_heuristics(testee1, testee2) + >>> assert is_provable_equal(testee1, testee2) """ return a == b or ( CannonicalizeBoundSymbolNames.apply(a) == CannonicalizeBoundSymbolNames.apply(b) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 9733f87009..0790c4b361 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -177,7 +177,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: assert isinstance(v.args[0], ir.Literal) if not ( int(v.args[0].value) == i - and ir_misc.is_equal_value_heuristics(v.args[1], first_expr) + and ir_misc.is_provable_equal(v.args[1], first_expr) ): # tuple argument differs, just continue with the rest of the tree return self.generic_visit(node) From 3367c2817d2c3d3739dbf0aae0ed6b2ff6722564 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 5 Jan 2024 00:25:15 +0100 Subject: [PATCH 13/21] Small fix --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 8bbd14e46e..40f70e6ae9 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -143,11 +143,17 @@ def apply( if use_global_type_inference: it_type_inference.infer_all(node, save_to_annex=True) + # TODO(tehrengruber): We don't want neither opcount preserving nor unconditionally inlining, + # but only force inline of lambda args. + new_node = InlineLambdas.apply( + node, opcount_preserving=True, force_inline_lambda_args=True + ) + new_node = cls( ignore_tuple_size=ignore_tuple_size, use_global_type_inference=use_global_type_inference, flags=flags, - ).visit(node) + ).visit(new_node) # inline to remove left-overs from LETIFY_MAKE_TUPLE_ELEMENTS. this is important # as otherwise two equal expressions containing a tuple will not be equal anymore From 3a2a0075e9473d475624398b10f0ced34285b459 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 6 Feb 2024 10:26:20 +0100 Subject: [PATCH 14/21] Cleanup --- .../ir_utils/common_pattern_matcher.py | 4 +- src/gt4py/next/iterator/ir_utils/ir_makers.py | 8 ++-- .../iterator/transforms/collapse_tuple.py | 45 ++++++++++--------- .../iterator/transforms/inline_lambdas.py | 2 - 4 files changed, 28 insertions(+), 31 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 1b1cc494d3..a4b074a4b6 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -27,11 +27,11 @@ def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]: ) -def is_let(node: itir.Node) -> bool: +def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]: """Match expression of the form `(λ(...) → ...)(...)`.""" return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda) -def is_if_call(node: itir.Expr): +def is_if_call(node: itir.Expr) -> TypeGuard[itir.FunCall]: """Match expression of the form `if_(cond, true_branch, false_branch)`.""" return isinstance(node, itir.FunCall) and node.fun == im.ref("if_") diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index a912923d31..4337e8512a 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -250,12 +250,10 @@ class let: """ @typing.overload - def __init__(self, var: str | itir.Sym, init_form: itir.Expr): - ... + def __init__(self, var: str | itir.Sym, init_form: itir.Expr): ... @typing.overload - def __init__(self, *args: Iterable[tuple[str | itir.Sym, itir.Expr]]): - ... + def __init__(self, *args: Iterable[tuple[str | itir.Sym, itir.Expr]]): ... def __init__(self, *args): if all(isinstance(arg, tuple) and len(arg) == 2 for arg in args): @@ -268,7 +266,7 @@ def __init__(self, *args): self.init_forms = [args[1]] else: raise TypeError( - "Invalid arguments. Expected a variable name and an init form or a list thereof." + "Invalid arguments: expected a variable name and an init form or a list thereof." ) def __call__(self, form): diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 8a8d6fc7e1..bf801bd084 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -60,6 +60,7 @@ def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr): def _is_trivial_make_tuple_call(node: ir.Expr): + """Given an `Expr` return if it is a `make_tuple` call with all elements `SymRef`s or `Literal`s.""" if not (isinstance(node, ir.FunCall) and node.fun == im.ref("make_tuple")): return False if not all( @@ -88,15 +89,17 @@ class Flag(enum.IntEnum): PROPAGATE_TUPLE_GET = 4 #: `{1, 2}` -> `(λ(_tuple_el_1, _tuple_el_2) → {_tuple_el_1, _tuple_el_2})(1, 2)` LETIFY_MAKE_TUPLE_ELEMENTS = 8 - #: TODO + #: Inverse of LETIFY_MAKE_TUPLE_ELEMENTS run after all other transformations + #: `(λ(_tuple_el_1, _tuple_el_2) → {_tuple_el_1, _tuple_el_2})(1, 2)` -> {1, 2}` REMOVE_LETIFIED_MAKE_TUPLE_ELEMENTS = 16 - #: TODO + #: `let(tup, {trivial_expr1, trivial_expr2})(foo(tup))` + #: -> `foo({trivial_expr1, trivial_expr2})` INLINE_TRIVIAL_MAKE_TUPLE = 32 - #: TODO + #: `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` PROPAGATE_TO_IF_ON_TUPLES = 64 - #: TODO + #: `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` PROPAGATE_NESTED_LET = 128 - #: TODO + #: `let(a, 1)(a)` -> `1` INLINE_TRIVIAL_LET = 256 ignore_tuple_size: bool @@ -143,15 +146,11 @@ def apply( if use_global_type_inference: it_type_inference.infer_all(node, save_to_annex=True) - # TODO(tehrengruber): We don't want neither opcount preserving nor unconditionally inlining, - # but only force inline of lambda args. - new_node = InlineLambdas.apply(node, opcount_preserving=True, force_inline_lambda_args=True) - new_node = cls( ignore_tuple_size=ignore_tuple_size, use_global_type_inference=use_global_type_inference, flags=flags, - ).visit(new_node) + ).visit(node) # inline to remove left-overs from LETIFY_MAKE_TUPLE_ELEMENTS. this is important # as otherwise two equal expressions containing a tuple will not be equal anymore @@ -220,15 +219,17 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: if is_let(node.args[1]): idx, let_expr = node.args return self.visit( - im.call(im.lambda_(*let_expr.fun.params)(im.tuple_get(idx, let_expr.fun.expr)))( - *let_expr.args + im.call( + im.lambda_(*let_expr.fun.params)(im.tuple_get(idx.value, let_expr.fun.expr)) # type: ignore[attr-defined] # ensured by is_let + )( + *let_expr.args # type: ignore[attr-defined] # ensured by is_let ) ) elif isinstance(node.args[1], ir.FunCall) and node.args[1].fun == im.ref("if_"): idx = node.args[0] cond, true_branch, false_branch = node.args[1].args return self.visit( - im.if_(cond, im.tuple_get(idx, true_branch), im.tuple_get(idx, false_branch)) + im.if_(cond, im.tuple_get(idx.value, true_branch), im.tuple_get(idx.value, false_branch)) ) # todo: check if visit needed if self.flags & self.Flag.LETIFY_MAKE_TUPLE_ELEMENTS and node.fun == ir.SymRef( @@ -251,7 +252,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: new_args.append(arg) if bound_vars: - return self.visit(im.let(*bound_vars.items())(im.call(node.fun)(*new_args))) + return self.visit(im.let(*bound_vars.items())(im.call(node.fun)(*new_args))) # type: ignore[arg-type] # mypy not smart enough if self.flags & self.Flag.INLINE_TRIVIAL_MAKE_TUPLE and is_let(node): # `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))` @@ -277,29 +278,29 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` outer_vars = {} inner_vars = {} - original_inner_expr = node.fun.expr - for arg_sym, arg in zip(node.fun.params, node.args): + original_inner_expr = node.fun.expr # type: ignore[attr-defined] # ensured by is_let + for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let assert arg_sym not in inner_vars # TODO: fix collisions if is_let(arg): - for sym, val in zip(arg.fun.params, arg.args): + for sym, val in zip(arg.fun.params, arg.args): # type: ignore[attr-defined] # ensured by is_let assert sym not in outer_vars # TODO: fix collisions outer_vars[sym] = val - inner_vars[arg_sym] = arg.fun.expr + inner_vars[arg_sym] = arg.fun.expr # type: ignore[attr-defined] # ensured by is_let else: inner_vars[arg_sym] = arg if outer_vars: node = self.visit( - im.let(*outer_vars.items())(im.let(*inner_vars.items())(original_inner_expr)) + im.let(*outer_vars.items())(im.let(*inner_vars.items())(original_inner_expr)) # type: ignore[arg-type] # mypy not smart enough ) if ( self.flags & self.Flag.INLINE_TRIVIAL_LET and is_let(node) - and isinstance(node.fun.expr, ir.SymRef) + and isinstance(node.fun.expr, ir.SymRef) # type: ignore[attr-defined] # ensured by is_let ): # `let(a, 1)(a)` -> `1` - for arg_sym, arg in zip(node.fun.params, node.args): - if node.fun.expr == im.ref(arg_sym.id): + for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let + if node.fun.expr == im.ref(arg_sym.id): # type: ignore[attr-defined] # ensured by is_let return arg return node diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index 2831287d82..0b89fe6d98 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -135,8 +135,6 @@ class InlineLambdas(PreserveLocationVisitor, NodeTranslator): force_inline_trivial_lift_args: bool - force_inline_lambda_args: bool - @classmethod def apply( cls, From d915af52feda5945dccf2f6b642ffdfdf62ff0b4 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 6 Feb 2024 13:10:15 +0100 Subject: [PATCH 15/21] Cleanup --- .../iterator/transforms/collapse_tuple.py | 131 +++++++++++------- tests/next_tests/definitions.py | 2 +- 2 files changed, 83 insertions(+), 50 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index bf801bd084..2045ed42d4 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -163,16 +163,40 @@ def apply( return new_node - def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: - node = self.generic_visit(node, **kwargs) + def visit_FunCall(self, node: ir.FunCall) -> ir.Node: + node = self.generic_visit(node) + return self.fp_transform(node) + + def fp_transform( + self, node: ir.Node + ) -> ir.Node: # todo: pass what transformations to do (one or all) + while True: + new_node = self.transform(node) + if new_node is None: + break + assert new_node != node + node = new_node + return node - if ( - self.flags & self.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET - and node.fun == ir.SymRef(id="make_tuple") - and all( - isinstance(arg, ir.FunCall) and arg.fun == ir.SymRef(id="tuple_get") - for arg in node.args - ) + def transform(self, node: ir.Node) -> Optional[ir.Node]: + if not isinstance(node, ir.FunCall): + return None + + for transformation in self.Flag: + if transformation == self.Flag.REMOVE_LETIFIED_MAKE_TUPLE_ELEMENTS: + continue + if self.flags & transformation: + # todo: remove flags and make it a list + method = getattr(self, f"transform_{transformation.name.lower()}") + result = method(node) + if result is not None: + return result + return None + + def transform_collapse_make_tuple_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: + if node.fun == ir.SymRef(id="make_tuple") and all( + isinstance(arg, ir.FunCall) and arg.fun == ir.SymRef(id="tuple_get") + for arg in node.args ): # `make_tuple(tuple_get(0, t), tuple_get(1, t), ..., tuple_get(N-1,t))` -> `t` assert isinstance(node.args[0], ir.FunCall) @@ -185,16 +209,17 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: int(v.args[0].value) == i and ir_misc.is_provable_equal(v.args[1], first_expr) ): # tuple argument differs, just continue with the rest of the tree - return self.generic_visit(node) + return None if self.ignore_tuple_size or _get_tuple_size( first_expr, self.use_global_type_inference ) == len(node.args): return first_expr + return None + def transform_collapse_tuple_get_make_tuple(self, node: ir.FunCall) -> Optional[ir.Node]: if ( - self.flags & self.Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE - and node.fun == ir.SymRef(id="tuple_get") + node.fun == ir.SymRef(id="tuple_get") and isinstance(node.args[1], ir.FunCall) and node.args[1].fun == ir.SymRef(id="make_tuple") and isinstance(node.args[0], ir.Literal) @@ -207,34 +232,36 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: make_tuple_call.args ), f"Index {idx} is out of bounds for tuple of size {len(make_tuple_call.args)}" return node.args[1].args[idx] + return None - if ( - self.flags & self.Flag.PROPAGATE_TUPLE_GET - and node.fun == ir.SymRef(id="tuple_get") - and isinstance( - node.args[0], ir.Literal - ) # TODO: extend to general symbols as long as the tail call in the let does not capture - ): + def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: + if node.fun == ir.SymRef(id="tuple_get") and isinstance( + node.args[0], ir.Literal + ): # TODO: extend to general symbols as long as the tail call in the let does not capture # `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))` if is_let(node.args[1]): idx, let_expr = node.args - return self.visit( - im.call( - im.lambda_(*let_expr.fun.params)(im.tuple_get(idx.value, let_expr.fun.expr)) # type: ignore[attr-defined] # ensured by is_let - )( - *let_expr.args # type: ignore[attr-defined] # ensured by is_let + return im.call( + im.lambda_(*let_expr.fun.params)( # type: ignore[attr-defined] # ensured by is_let + self.fp_transform(im.tuple_get(idx.value, let_expr.fun.expr)) # type: ignore[attr-defined] # ensured by is_let ) - ) + )( + *let_expr.args + ) # type: ignore[attr-defined] # ensured by is_let elif isinstance(node.args[1], ir.FunCall) and node.args[1].fun == im.ref("if_"): idx = node.args[0] cond, true_branch, false_branch = node.args[1].args - return self.visit( - im.if_(cond, im.tuple_get(idx.value, true_branch), im.tuple_get(idx.value, false_branch)) + return im.if_( + cond, + self.fp_transform( + im.tuple_get(idx.value, true_branch) + ), # todo: call transformation directly + self.fp_transform(im.tuple_get(idx.value, false_branch)), ) # todo: check if visit needed + return None - if self.flags & self.Flag.LETIFY_MAKE_TUPLE_ELEMENTS and node.fun == ir.SymRef( - id="make_tuple" - ): + def transform_letify_make_tuple_elements(self, node: ir.FunCall) -> Optional[ir.Node]: + if node.fun == ir.SymRef(id="make_tuple"): # `make_tuple(expr1, expr1)` # -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))` bound_vars: dict[str, ir.Expr] = {} @@ -252,55 +279,61 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: new_args.append(arg) if bound_vars: - return self.visit(im.let(*bound_vars.items())(im.call(node.fun)(*new_args))) # type: ignore[arg-type] # mypy not smart enough + return self.fp_transform(im.let(*bound_vars.items())(im.call(node.fun)(*new_args))) # type: ignore[arg-type] # mypy not smart enough + return None - if self.flags & self.Flag.INLINE_TRIVIAL_MAKE_TUPLE and is_let(node): + def transform_inline_trivial_make_tuple(self, node: ir.FunCall) -> Optional[ir.Node]: + if is_let(node): # `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))` # -> `foo(make_tuple(trivial_expr1, trivial_expr2))` eligible_params = [_is_trivial_make_tuple_call(arg) for arg in node.args] if any(eligible_params): return self.visit(inline_lambda(node, eligible_params=eligible_params)) + return None - if self.flags & self.Flag.PROPAGATE_TO_IF_ON_TUPLES and not node.fun == im.ref("if_"): + def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.Node]: + if not node.fun == im.ref("if_"): # TODO(tehrengruber): This significantly increases the size of the tree. Revisit. # TODO(tehrengruber): Only inline if type of branch value is a tuple. + # Examples: # `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` + # `let (b, if cond then {1, 2} else {3, 4})) b[0]` + # -> `if cond then let(b, {1, 2})(b[0]) else let(b, {3, 4})(b[0])` for i, arg in enumerate(node.args): if is_if_call(arg): cond, true_branch, false_branch = arg.args - new_true_branch = self.visit(_with_altered_arg(node, i, true_branch), **kwargs) - new_false_branch = self.visit( - _with_altered_arg(node, i, false_branch), **kwargs - ) + new_true_branch = self.fp_transform(_with_altered_arg(node, i, true_branch)) + new_false_branch = self.fp_transform(_with_altered_arg(node, i, false_branch)) return im.if_(cond, new_true_branch, new_false_branch) + return None - if self.flags & self.Flag.PROPAGATE_NESTED_LET and is_let(node): + def transform_propagate_nested_let(self, node: ir.FunCall) -> Optional[ir.Node]: + if is_let(node): # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` outer_vars = {} inner_vars = {} original_inner_expr = node.fun.expr # type: ignore[attr-defined] # ensured by is_let for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let - assert arg_sym not in inner_vars # TODO: fix collisions + assert arg_sym not in inner_vars # TODO(tehrengruber): fix collisions if is_let(arg): for sym, val in zip(arg.fun.params, arg.args): # type: ignore[attr-defined] # ensured by is_let - assert sym not in outer_vars # TODO: fix collisions + assert sym not in outer_vars # TODO(tehrengruber): fix collisions outer_vars[sym] = val inner_vars[arg_sym] = arg.fun.expr # type: ignore[attr-defined] # ensured by is_let else: inner_vars[arg_sym] = arg if outer_vars: - node = self.visit( - im.let(*outer_vars.items())(im.let(*inner_vars.items())(original_inner_expr)) # type: ignore[arg-type] # mypy not smart enough + return self.fp_transform( + im.let(*outer_vars.items())( # type: ignore[arg-type] # mypy not smart enough + self.fp_transform(im.let(*inner_vars.items())(original_inner_expr)) + ) ) + return None - if ( - self.flags & self.Flag.INLINE_TRIVIAL_LET - and is_let(node) - and isinstance(node.fun.expr, ir.SymRef) # type: ignore[attr-defined] # ensured by is_let - ): + def transform_inline_trivial_let(self, node: ir.FunCall) -> Optional[ir.Node]: + if is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let # `let(a, 1)(a)` -> `1` for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let if node.fun.expr == im.ref(arg_sym.id): # type: ignore[attr-defined] # ensured by is_let return arg - - return node + return None diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index ac9cd789f3..56b220e0e9 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -148,7 +148,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): COMMON_SKIP_TEST_LIST = [ (REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), (USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE), - # (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), From fcdb8aecc33d6cd78e33a3e786e6880335eb5ab0 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 6 Feb 2024 13:14:45 +0100 Subject: [PATCH 16/21] Cleanup --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 2045ed42d4..0c441ccc54 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -257,7 +257,7 @@ def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: im.tuple_get(idx.value, true_branch) ), # todo: call transformation directly self.fp_transform(im.tuple_get(idx.value, false_branch)), - ) # todo: check if visit needed + ) return None def transform_letify_make_tuple_elements(self, node: ir.FunCall) -> Optional[ir.Node]: From 796f3a76101e0644d9bec8f1f1361496a2114598 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 7 Feb 2024 11:24:20 +0100 Subject: [PATCH 17/21] Revert debug changes to caching --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 4 ++-- src/gt4py/next/program_processors/runners/gtfn.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 0c441ccc54..4b18ce72de 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -246,8 +246,8 @@ def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: self.fp_transform(im.tuple_get(idx.value, let_expr.fun.expr)) # type: ignore[attr-defined] # ensured by is_let ) )( - *let_expr.args - ) # type: ignore[attr-defined] # ensured by is_let + *let_expr.args # type: ignore[attr-defined] # ensured by is_let + ) elif isinstance(node.args[1], ir.FunCall) and node.args[1].fun == im.ref("if_"): idx = node.args[0] cond, true_branch, false_branch = node.args[1].args diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index b93bcb412e..baa45ddc0e 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -26,7 +26,7 @@ from gt4py.next.otf import languages, recipes, stages, step_types, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import cache, compiler -from gt4py.next.otf.compilation.build_systems import cmake, compiledb +from gt4py.next.otf.compilation.build_systems import compiledb from gt4py.next.program_processors import otf_compile_executor from gt4py.next.program_processors.codegens.gtfn import gtfn_module from gt4py.next.type_system.type_translation import from_value @@ -130,7 +130,7 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int: ) GTFN_DEFAULT_COMPILE_STEP: step_types.CompilationStep = compiler.Compiler( - cache_strategy=cache.Strategy.PERSISTENT, builder_factory=cmake.CMakeFactory() + cache_strategy=cache.Strategy.SESSION, builder_factory=compiledb.CompiledbFactory() ) From bdc92215b3eecf539fcd01d70c253f85b22b57eb Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 7 Feb 2024 11:41:54 +0100 Subject: [PATCH 18/21] Cleanup --- .../next/iterator/transforms/collapse_tuple.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 4b18ce72de..46428e7560 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -71,6 +71,11 @@ def _is_trivial_make_tuple_call(node: ir.Expr): return True +# TODO(tehrengruber): Conceptually the structure of this pass makes sense: Visit depth first, +# transform each node until no transformations apply anymore, whenever a node is to be transformed +# go through all available transformation and apply them. However the final result here still +# reads a little convoluted and is also different to how we write other transformations. We +# should revisit the pattern here and try to find a more general mechanism. @dataclasses.dataclass(frozen=True) class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator): """ @@ -80,6 +85,9 @@ class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator): - `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i` """ + # TODO(tehrengruber): This Flag machanism is a little low level. What we actually want + # is something like a pass manager, where for each pattern we have a corresponding + # transformation, etc. class Flag(enum.IntEnum): #: `make_tuple(tuple_get(0, t), tuple_get(1, t), ..., tuple_get(N-1,t))` -> `t` COLLAPSE_MAKE_TUPLE_TUPLE_GET = 1 @@ -167,9 +175,7 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.Node: node = self.generic_visit(node) return self.fp_transform(node) - def fp_transform( - self, node: ir.Node - ) -> ir.Node: # todo: pass what transformations to do (one or all) + def fp_transform(self, node: ir.Node) -> ir.Node: while True: new_node = self.transform(node) if new_node is None: @@ -186,7 +192,6 @@ def transform(self, node: ir.Node) -> Optional[ir.Node]: if transformation == self.Flag.REMOVE_LETIFIED_MAKE_TUPLE_ELEMENTS: continue if self.flags & transformation: - # todo: remove flags and make it a list method = getattr(self, f"transform_{transformation.name.lower()}") result = method(node) if result is not None: @@ -253,9 +258,7 @@ def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: cond, true_branch, false_branch = node.args[1].args return im.if_( cond, - self.fp_transform( - im.tuple_get(idx.value, true_branch) - ), # todo: call transformation directly + self.fp_transform(im.tuple_get(idx.value, true_branch)), self.fp_transform(im.tuple_get(idx.value, false_branch)), ) return None From 9902d638082e50a1af3b9df488a88523cedfb4b9 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 7 Feb 2024 22:46:25 +0100 Subject: [PATCH 19/21] Address reviewer comments --- src/gt4py/next/iterator/ir_utils/misc.py | 9 +- .../iterator/transforms/collapse_tuple.py | 73 ++++++++-------- .../next/iterator/transforms/pass_manager.py | 5 +- .../transforms_tests/test_collapse_tuple.py | 84 ++++++++++++++----- 4 files changed, 110 insertions(+), 61 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py index fcbd7962b8..5bf2307457 100644 --- a/src/gt4py/next/iterator/ir_utils/misc.py +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -57,13 +57,16 @@ def visit_SymRef(self, node: itir.SymRef, *, sym_map: dict[str, str]): return im.ref(sym_map[node.id]) if node.id in sym_map else node -def is_provable_equal(a: itir.Expr, b: itir.Expr): +def is_equal(a: itir.Expr, b: itir.Expr): """ - Return true if, bot not only if, two expression (with equal scope) have the same value. + Return true if, but not only if, two expression (with equal scope) have the same value. + + Be aware that this function might return false even though the two expression have the same + value. >>> testee1 = im.lambda_("a")(im.plus("a", "b")) >>> testee2 = im.lambda_("c")(im.plus("c", "b")) - >>> assert is_provable_equal(testee1, testee2) + >>> assert is_equal(testee1, testee2) """ return a == b or ( CannonicalizeBoundSymbolNames.apply(a) == CannonicalizeBoundSymbolNames.apply(b) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 46428e7560..fff3df523e 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -13,6 +13,8 @@ # SPDX-License-Identifier: GPL-3.0-or-later import dataclasses import enum +import functools +import operator from typing import Optional from gt4py import eve @@ -60,7 +62,7 @@ def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr): def _is_trivial_make_tuple_call(node: ir.Expr): - """Given an `Expr` return if it is a `make_tuple` call with all elements `SymRef`s or `Literal`s.""" + """Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof.""" if not (isinstance(node, ir.FunCall) and node.fun == im.ref("make_tuple")): return False if not all( @@ -88,41 +90,32 @@ class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator): # TODO(tehrengruber): This Flag machanism is a little low level. What we actually want # is something like a pass manager, where for each pattern we have a corresponding # transformation, etc. - class Flag(enum.IntEnum): + class Flag(enum.Flag): #: `make_tuple(tuple_get(0, t), tuple_get(1, t), ..., tuple_get(N-1,t))` -> `t` - COLLAPSE_MAKE_TUPLE_TUPLE_GET = 1 + COLLAPSE_MAKE_TUPLE_TUPLE_GET = enum.auto() #: `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i` - COLLAPSE_TUPLE_GET_MAKE_TUPLE = 2 + COLLAPSE_TUPLE_GET_MAKE_TUPLE = enum.auto() #: `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))` - PROPAGATE_TUPLE_GET = 4 + PROPAGATE_TUPLE_GET = enum.auto() #: `{1, 2}` -> `(λ(_tuple_el_1, _tuple_el_2) → {_tuple_el_1, _tuple_el_2})(1, 2)` - LETIFY_MAKE_TUPLE_ELEMENTS = 8 - #: Inverse of LETIFY_MAKE_TUPLE_ELEMENTS run after all other transformations - #: `(λ(_tuple_el_1, _tuple_el_2) → {_tuple_el_1, _tuple_el_2})(1, 2)` -> {1, 2}` - REMOVE_LETIFIED_MAKE_TUPLE_ELEMENTS = 16 + LETIFY_MAKE_TUPLE_ELEMENTS = enum.auto() #: `let(tup, {trivial_expr1, trivial_expr2})(foo(tup))` #: -> `foo({trivial_expr1, trivial_expr2})` - INLINE_TRIVIAL_MAKE_TUPLE = 32 + INLINE_TRIVIAL_MAKE_TUPLE = enum.auto() #: `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` - PROPAGATE_TO_IF_ON_TUPLES = 64 + PROPAGATE_TO_IF_ON_TUPLES = enum.auto() #: `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` - PROPAGATE_NESTED_LET = 128 + PROPAGATE_NESTED_LET = enum.auto() #: `let(a, 1)(a)` -> `1` - INLINE_TRIVIAL_LET = 256 + INLINE_TRIVIAL_LET = enum.auto() + + @classmethod + def all(self): # noqa: A003 # shadowing a python builtin + return functools.reduce(operator.or_, self.__members__.values()) ignore_tuple_size: bool use_global_type_inference: bool - flags: int = ( - Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET - | Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE - | Flag.PROPAGATE_TUPLE_GET - | Flag.LETIFY_MAKE_TUPLE_ELEMENTS - | Flag.REMOVE_LETIFIED_MAKE_TUPLE_ELEMENTS - | Flag.INLINE_TRIVIAL_MAKE_TUPLE - | Flag.PROPAGATE_TO_IF_ON_TUPLES - | Flag.PROPAGATE_NESTED_LET - | Flag.INLINE_TRIVIAL_LET - ) + flags: Flag = Flag.all() PRESERVED_ANNEX_ATTRS = ("type",) @@ -141,14 +134,23 @@ def apply( *, ignore_tuple_size: bool = False, use_global_type_inference: bool = False, + remove_letified_make_tuple_elements: bool = True, # manually passing flags is mostly for allowing separate testing of the modes flags=None, ) -> ir.Node: """ Simplifies `make_tuple`, `tuple_get` calls. - If `ignore_tuple_size`, apply the transformation even if length of the inner tuple - is greater than the length of the outer tuple. + Arguments: + node: The node to transform. + + Keyword arguments: + ignore_tuple_size: Apply the transformation even if length of the inner tuple is greater + than the length of the outer tuple. + use_global_type_inference: Run global type inference to determine tuple sizes. + remove_letified_make_tuple_elements: Run `InlineLambdas` as a post-processing step + to remove left-overs from `LETIFY_MAKE_TUPLE_ELEMENTS` transformation. + `(λ(_tuple_el_1, _tuple_el_2) → {_tuple_el_1, _tuple_el_2})(1, 2)` -> {1, 2}` """ flags = flags or cls.flags if use_global_type_inference: @@ -163,8 +165,9 @@ def apply( # inline to remove left-overs from LETIFY_MAKE_TUPLE_ELEMENTS. this is important # as otherwise two equal expressions containing a tuple will not be equal anymore # and the CSE pass can not remove them. - # TODO: test case for `scan(lambda carry: {1, 2})` (see solve_nonhydro_stencil_52_like_z_q_tup) - if flags & cls.Flag.REMOVE_LETIFIED_MAKE_TUPLE_ELEMENTS: + # TODO(tehrengruber): test case for `scan(lambda carry: {1, 2})` + # (see solve_nonhydro_stencil_52_like_z_q_tup) + if remove_letified_make_tuple_elements: new_node = InlineLambdas.apply( new_node, opcount_preserving=True, force_inline_lambda_args=False ) @@ -188,9 +191,7 @@ def transform(self, node: ir.Node) -> Optional[ir.Node]: if not isinstance(node, ir.FunCall): return None - for transformation in self.Flag: - if transformation == self.Flag.REMOVE_LETIFIED_MAKE_TUPLE_ELEMENTS: - continue + for transformation in self.Flag.all(): if self.flags & transformation: method = getattr(self, f"transform_{transformation.name.lower()}") result = method(node) @@ -210,9 +211,7 @@ def transform_collapse_make_tuple_tuple_get(self, node: ir.FunCall) -> Optional[ for i, v in enumerate(node.args): assert isinstance(v, ir.FunCall) assert isinstance(v.args[0], ir.Literal) - if not ( - int(v.args[0].value) == i and ir_misc.is_provable_equal(v.args[1], first_expr) - ): + if not (int(v.args[0].value) == i and ir_misc.is_equal(v.args[1], first_expr)): # tuple argument differs, just continue with the rest of the tree return None @@ -240,9 +239,9 @@ def transform_collapse_tuple_get_make_tuple(self, node: ir.FunCall) -> Optional[ return None def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: - if node.fun == ir.SymRef(id="tuple_get") and isinstance( - node.args[0], ir.Literal - ): # TODO: extend to general symbols as long as the tail call in the let does not capture + if node.fun == ir.SymRef(id="tuple_get") and isinstance(node.args[0], ir.Literal): + # TODO(tehrengruber): extend to general symbols as long as the tail call in the let + # does not capture # `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))` if is_let(node.args[1]): idx, let_expr = node.args diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 5f66afb3aa..b6de9456bd 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -114,7 +114,10 @@ def apply_common_transforms( # to limit number of times global type inference is executed, only in the last iterations. use_global_type_inference=inlined == ir, ) - inlined = PropagateDeref.apply(inlined) # todo: document + # This pass is required such that a deref outside of a + # `tuple_get(make_tuple(let(...), ...))` call is propagated into the let after the + # `tuple_get` is removed by the `CollapseTuple` pass. + inlined = PropagateDeref.apply(inlined) if inlined == ir: break diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index dd4406c86d..330f66bee5 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -20,7 +20,11 @@ def test_simple_make_tuple_tuple_get(): tuple_of_size_2 = im.make_tuple("first", "second") testee = im.make_tuple(im.tuple_get(0, tuple_of_size_2), im.tuple_get(1, tuple_of_size_2)) - actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + ) expected = tuple_of_size_2 assert actual == expected @@ -32,7 +36,11 @@ def test_nested_make_tuple_tuple_get(): im.tuple_get(0, tup_of_size2_from_lambda), im.tuple_get(1, tup_of_size2_from_lambda) ) - actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + ) assert actual == tup_of_size2_from_lambda @@ -42,7 +50,11 @@ def test_different_tuples_make_tuple_tuple_get(): t1 = im.make_tuple("foo1", "bar1") testee = im.make_tuple(im.tuple_get(0, t0), im.tuple_get(1, t1)) - actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + ) assert actual == testee # did nothing @@ -50,13 +62,21 @@ def test_different_tuples_make_tuple_tuple_get(): def test_incompatible_order_make_tuple_tuple_get(): tuple_of_size_2 = im.make_tuple("first", "second") testee = im.make_tuple(im.tuple_get(1, tuple_of_size_2), im.tuple_get(0, tuple_of_size_2)) - actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + ) assert actual == testee # did nothing def test_incompatible_size_make_tuple_tuple_get(): testee = im.make_tuple(im.tuple_get(0, im.make_tuple("first", "second"))) - actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + ) assert actual == testee # did nothing @@ -71,14 +91,22 @@ def test_merged_with_smaller_outer_size_make_tuple_tuple_get(): def test_simple_tuple_get_make_tuple(): expected = im.ref("bar") testee = im.tuple_get(1, im.make_tuple("foo", expected)) - actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE, + ) assert expected == actual def test_propagate_tuple_get(): expected = im.let(("el1", 1), ("el2", 2))(im.tuple_get(0, im.make_tuple("el1", "el2"))) testee = im.tuple_get(0, im.let(("el1", 1), ("el2", 2))(im.make_tuple("el1", "el2"))) - actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.PROPAGATE_TUPLE_GET) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.PROPAGATE_TUPLE_GET, + ) assert expected == actual @@ -89,7 +117,11 @@ def test_letify_make_tuple_elements(): im.make_tuple("_tuple_el_1", "_tuple_el_2") ) - actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, + ) assert actual == expected @@ -97,7 +129,11 @@ def test_letify_make_tuple_with_trivial_elements(): testee = im.let(("a", 1), ("b", 2))(im.make_tuple("a", "b")) expected = testee # did nothing - actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, + ) assert actual == expected @@ -105,7 +141,11 @@ def test_inline_trivial_make_tuple(): testee = im.let("tup", im.make_tuple("a", "b"))("tup") expected = im.make_tuple("a", "b") - actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.INLINE_TRIVIAL_MAKE_TUPLE) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.INLINE_TRIVIAL_MAKE_TUPLE, + ) assert actual == expected @@ -114,7 +154,11 @@ def test_propagate_to_if_on_tuples(): expected = im.if_( "cond", im.tuple_get(0, im.make_tuple(1, 2)), im.tuple_get(0, im.make_tuple(3, 4)) ) - actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + ) assert actual == expected @@ -127,9 +171,9 @@ def test_propagate_to_if_on_tuples_with_let(): ) actual = CollapseTuple.apply( testee, + remove_letified_make_tuple_elements=True, flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES - | CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS - | CollapseTuple.Flag.REMOVE_LETIFIED_MAKE_TUPLE_ELEMENTS, + | CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, ) assert actual == expected @@ -137,18 +181,18 @@ def test_propagate_to_if_on_tuples_with_let(): def test_propagate_nested_lift(): testee = im.let("a", im.let("b", 1)("a_val"))("a") expected = im.let("b", 1)(im.let("a", "a_val")("a")) - actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.PROPAGATE_NESTED_LET) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.PROPAGATE_NESTED_LET, + ) assert actual == expected -def test_collapse_complicated_(): - # TODO: fuse with test_propagate_to_if_on_tuples_with_let +def test_if_on_tuples_with_let(): testee = im.let("val", im.if_("cond", im.make_tuple(1, 2), im.make_tuple(3, 4)))( im.tuple_get(0, "val") ) expected = im.if_("cond", 1, 3) - actual = CollapseTuple.apply( - testee, - # flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES - ) + actual = CollapseTuple.apply(testee, remove_letified_make_tuple_elements=False) assert actual == expected From 650a93416defcbbbdb487f89999fcae6e80772c1 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 8 Feb 2024 10:29:30 +0100 Subject: [PATCH 20/21] Fix broken CI --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index fff3df523e..dea9475ffc 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -191,8 +191,9 @@ def transform(self, node: ir.Node) -> Optional[ir.Node]: if not isinstance(node, ir.FunCall): return None - for transformation in self.Flag.all(): + for transformation in self.Flag: if self.flags & transformation: + assert isinstance(transformation.name, str) method = getattr(self, f"transform_{transformation.name.lower()}") result = method(node) if result is not None: From e9f6fb181da391e2cee80dea583a6440a58384ad Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 13 Feb 2024 10:47:23 +0100 Subject: [PATCH 21/21] Address reviewer comments --- src/gt4py/next/iterator/ir_utils/misc.py | 8 +++++++- src/gt4py/next/iterator/transforms/collapse_tuple.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py index 5bf2307457..4336649d06 100644 --- a/src/gt4py/next/iterator/ir_utils/misc.py +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -59,7 +59,7 @@ def visit_SymRef(self, node: itir.SymRef, *, sym_map: dict[str, str]): def is_equal(a: itir.Expr, b: itir.Expr): """ - Return true if, but not only if, two expression (with equal scope) have the same value. + Return true if two expressions have provably equal values. Be aware that this function might return false even though the two expression have the same value. @@ -67,7 +67,13 @@ def is_equal(a: itir.Expr, b: itir.Expr): >>> testee1 = im.lambda_("a")(im.plus("a", "b")) >>> testee2 = im.lambda_("c")(im.plus("c", "b")) >>> assert is_equal(testee1, testee2) + + >>> testee1 = im.lambda_("a")(im.plus("a", "b")) + >>> testee2 = im.lambda_("c")(im.plus("c", "d")) + >>> assert not is_equal(testee1, testee2) """ + # TODO(tehrengruber): Extend this function cover more cases than just those with equal + # structure, e.g., by also canonicalization of the structure. return a == b or ( CannonicalizeBoundSymbolNames.apply(a) == CannonicalizeBoundSymbolNames.apply(b) ) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index dea9475ffc..51daffed05 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -87,7 +87,7 @@ class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator): - `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i` """ - # TODO(tehrengruber): This Flag machanism is a little low level. What we actually want + # TODO(tehrengruber): This Flag mechanism is a little low level. What we actually want # is something like a pass manager, where for each pattern we have a corresponding # transformation, etc. class Flag(enum.Flag):