Skip to content

Commit

Permalink
feature[next]: Improve CollapseTuple pass (#1350)
Browse files Browse the repository at this point in the history
Significantly improves the collapse tuple pass preparing wider support for `if` statements.
  • Loading branch information
tehrengruber authored Feb 13, 2024
1 parent 1d305e1 commit 2970575
Show file tree
Hide file tree
Showing 8 changed files with 513 additions and 58 deletions.
11 changes: 11 additions & 0 deletions src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import TypeGuard

from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im


def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]:
Expand All @@ -24,3 +25,13 @@ def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]:
and isinstance(arg.fun.fun, itir.SymRef)
and arg.fun.fun.id == "lift"
)


def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]:
"""Match expression of the form `(λ(...) → ...)(...)`."""
return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda)


def is_if_call(node: itir.Expr) -> TypeGuard[itir.FunCall]:
"""Match expression of the form `if_(cond, true_branch, false_branch)`."""
return isinstance(node, itir.FunCall) and node.fun == im.ref("if_")
30 changes: 23 additions & 7 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from typing import Callable, Union
import typing
from typing import Callable, Iterable, Union

from gt4py._core import definitions as core_defs
from gt4py.next.iterator import ir as itir
Expand Down Expand Up @@ -242,16 +243,31 @@ class let:
--------
>>> str(let("a", "b")("a")) # doctest: +ELLIPSIS
'(λ(a) → a)(b)'
>>> str(let("a", 1,
... "b", 2
>>> str(let(("a", 1),
... ("b", 2)
... )(plus("a", "b")))
'(λ(a, b) → a + b)(1, 2)'
"""

def __init__(self, *vars_and_values):
assert len(vars_and_values) % 2 == 0
self.vars = vars_and_values[0::2]
self.init_forms = vars_and_values[1::2]
@typing.overload
def __init__(self, var: str | itir.Sym, init_form: itir.Expr): ...

@typing.overload
def __init__(self, *args: Iterable[tuple[str | itir.Sym, itir.Expr]]): ...

def __init__(self, *args):
if all(isinstance(arg, tuple) and len(arg) == 2 for arg in args):
assert isinstance(args, tuple)
assert all(isinstance(arg, tuple) and len(arg) == 2 for arg in args)
self.vars = [var for var, _ in args]
self.init_forms = [init_form for _, init_form in args]
elif len(args) == 2:
self.vars = [args[0]]
self.init_forms = [args[1]]
else:
raise TypeError(
"Invalid arguments: expected a variable name and an init form or a list thereof."
)

def __call__(self, form):
return call(lambda_(*self.vars)(form))(*self.init_forms)
Expand Down
79 changes: 79 additions & 0 deletions src/gt4py/next/iterator/ir_utils/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2023, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

import dataclasses
from collections import ChainMap

from gt4py import eve
from gt4py.eve import utils as eve_utils
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im


@dataclasses.dataclass(frozen=True)
class CannonicalizeBoundSymbolNames(eve.NodeTranslator):
"""
Given an iterator expression cannonicalize all bound symbol names.
If two such expression are in the same scope and equal so are their values.
>>> testee1 = im.lambda_("a")(im.plus("a", "b"))
>>> cannonicalized_testee1 = CannonicalizeBoundSymbolNames.apply(testee1)
>>> str(cannonicalized_testee1)
'λ(_csym_1) → _csym_1 + b'
>>> testee2 = im.lambda_("c")(im.plus("c", "b"))
>>> cannonicalized_testee2 = CannonicalizeBoundSymbolNames.apply(testee2)
>>> assert cannonicalized_testee1 == cannonicalized_testee2
"""

_uids: eve_utils.UIDGenerator = dataclasses.field(
init=False, repr=False, default_factory=lambda: eve_utils.UIDGenerator(prefix="_csym")
)

@classmethod
def apply(cls, node: itir.Expr):
return cls().visit(node, sym_map=ChainMap({}))

def visit_Lambda(self, node: itir.Lambda, *, sym_map: ChainMap):
sym_map = sym_map.new_child()
for param in node.params:
sym_map[str(param.id)] = self._uids.sequential_id()

return im.lambda_(*sym_map.values())(self.visit(node.expr, sym_map=sym_map))

def visit_SymRef(self, node: itir.SymRef, *, sym_map: dict[str, str]):
return im.ref(sym_map[node.id]) if node.id in sym_map else node


def is_equal(a: itir.Expr, b: itir.Expr):
"""
Return true if two expressions have provably equal values.
Be aware that this function might return false even though the two expression have the same
value.
>>> testee1 = im.lambda_("a")(im.plus("a", "b"))
>>> testee2 = im.lambda_("c")(im.plus("c", "b"))
>>> assert is_equal(testee1, testee2)
>>> testee1 = im.lambda_("a")(im.plus("a", "b"))
>>> testee2 = im.lambda_("c")(im.plus("c", "d"))
>>> assert not is_equal(testee1, testee2)
"""
# TODO(tehrengruber): Extend this function cover more cases than just those with equal
# structure, e.g., by also canonicalization of the structure.
return a == b or (
CannonicalizeBoundSymbolNames.apply(a) == CannonicalizeBoundSymbolNames.apply(b)
)
Loading

0 comments on commit 2970575

Please sign in to comment.