Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature[next]: Improve CollapseTuple pass #1350

Merged
merged 31 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
fe49d60
Improve CollapseTuple pass
tehrengruber Oct 17, 2023
cd7b6e3
Add comments
tehrengruber Oct 17, 2023
e49eae0
Bugfixes for scan
tehrengruber Oct 18, 2023
4e750b4
Introduce `_is_equal_value_heuristics` to avoid double visit
tehrengruber Oct 21, 2023
29560a3
Merge origin/main
tehrengruber Nov 27, 2023
e6c9f44
Fix tests
tehrengruber Nov 28, 2023
5a27524
Fix tests
tehrengruber Nov 28, 2023
367a05f
Merge remote-tracking branch 'origin/main' into improve_collapse_tuple
tehrengruber Nov 28, 2023
0437fce
Move iterator utils to dedicated module (in preparation for other PR …
tehrengruber Nov 28, 2023
36039a4
Fix format
tehrengruber Nov 28, 2023
ea42a16
Merge branch 'refactor_ir_utils' into improve_collapse_tuple
tehrengruber Nov 28, 2023
8772906
Fix tests
tehrengruber Nov 29, 2023
ace2dc0
Fix tests
tehrengruber Nov 29, 2023
2621f89
Merge branch 'refactor_ir_utils' into improve_collapse_tuple
tehrengruber Nov 29, 2023
8b7a6d7
Fix tests
tehrengruber Nov 29, 2023
b5fd847
Merge remote-tracking branch 'origin/main' into improve_collapse_tuple
tehrengruber Nov 29, 2023
6113312
Fix tests
tehrengruber Nov 29, 2023
6680075
Merge origin/main
tehrengruber Jan 4, 2024
3367c28
Small fix
tehrengruber Jan 4, 2024
c22cd35
Merge origin/main
tehrengruber Jan 25, 2024
7c15360
Merge remote-tracking branch 'origin/main' into improve_collapse_tuple
tehrengruber Feb 3, 2024
3a2a007
Cleanup
tehrengruber Feb 6, 2024
d915af5
Cleanup
tehrengruber Feb 6, 2024
fcdb8ae
Cleanup
tehrengruber Feb 6, 2024
8d4f93e
Merge remote-tracking branch 'origin/main' into improve_collapse_tuple
tehrengruber Feb 6, 2024
796f3a7
Revert debug changes to caching
tehrengruber Feb 7, 2024
bdc9221
Cleanup
tehrengruber Feb 7, 2024
9902d63
Address reviewer comments
tehrengruber Feb 7, 2024
650a934
Fix broken CI
tehrengruber Feb 8, 2024
00bbf2b
Merge remote-tracking branch 'origin/main' into improve_collapse_tuple
tehrengruber Feb 8, 2024
e9f6fb1
Address reviewer comments
tehrengruber Feb 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import TypeGuard

from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im


def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]:
Expand All @@ -24,3 +25,13 @@ def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]:
and isinstance(arg.fun.fun, itir.SymRef)
and arg.fun.fun.id == "lift"
)


def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]:
"""Match expression of the form `(λ(...) → ...)(...)`."""
return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda)


def is_if_call(node: itir.Expr) -> TypeGuard[itir.FunCall]:
"""Match expression of the form `if_(cond, true_branch, false_branch)`."""
return isinstance(node, itir.FunCall) and node.fun == im.ref("if_")
30 changes: 23 additions & 7 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from typing import Callable, Union
import typing
from typing import Callable, Iterable, Union

from gt4py._core import definitions as core_defs
from gt4py.next.iterator import ir as itir
Expand Down Expand Up @@ -242,16 +243,31 @@ class let:
--------
>>> str(let("a", "b")("a")) # doctest: +ELLIPSIS
'(λ(a) → a)(b)'
>>> str(let("a", 1,
... "b", 2
>>> str(let(("a", 1),
... ("b", 2)
... )(plus("a", "b")))
'(λ(a, b) → a + b)(1, 2)'
"""

def __init__(self, *vars_and_values):
assert len(vars_and_values) % 2 == 0
self.vars = vars_and_values[0::2]
self.init_forms = vars_and_values[1::2]
@typing.overload
def __init__(self, var: str | itir.Sym, init_form: itir.Expr): ...

@typing.overload
def __init__(self, *args: Iterable[tuple[str | itir.Sym, itir.Expr]]): ...

def __init__(self, *args):
if all(isinstance(arg, tuple) and len(arg) == 2 for arg in args):
assert isinstance(args, tuple)
assert all(isinstance(arg, tuple) and len(arg) == 2 for arg in args)
self.vars = [var for var, _ in args]
self.init_forms = [init_form for _, init_form in args]
elif len(args) == 2:
self.vars = [args[0]]
self.init_forms = [args[1]]
else:
raise TypeError(
"Invalid arguments: expected a variable name and an init form or a list thereof."
)

def __call__(self, form):
return call(lambda_(*self.vars)(form))(*self.init_forms)
Expand Down
79 changes: 79 additions & 0 deletions src/gt4py/next/iterator/ir_utils/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2023, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

import dataclasses
from collections import ChainMap

from gt4py import eve
from gt4py.eve import utils as eve_utils
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im


@dataclasses.dataclass(frozen=True)
class CannonicalizeBoundSymbolNames(eve.NodeTranslator):
"""
Given an iterator expression cannonicalize all bound symbol names.

If two such expression are in the same scope and equal so are their values.

>>> testee1 = im.lambda_("a")(im.plus("a", "b"))
>>> cannonicalized_testee1 = CannonicalizeBoundSymbolNames.apply(testee1)
>>> str(cannonicalized_testee1)
'λ(_csym_1) → _csym_1 + b'

>>> testee2 = im.lambda_("c")(im.plus("c", "b"))
>>> cannonicalized_testee2 = CannonicalizeBoundSymbolNames.apply(testee2)
>>> assert cannonicalized_testee1 == cannonicalized_testee2
"""

_uids: eve_utils.UIDGenerator = dataclasses.field(
init=False, repr=False, default_factory=lambda: eve_utils.UIDGenerator(prefix="_csym")
)

@classmethod
def apply(cls, node: itir.Expr):
return cls().visit(node, sym_map=ChainMap({}))

def visit_Lambda(self, node: itir.Lambda, *, sym_map: ChainMap):
sym_map = sym_map.new_child()
for param in node.params:
sym_map[str(param.id)] = self._uids.sequential_id()

return im.lambda_(*sym_map.values())(self.visit(node.expr, sym_map=sym_map))

def visit_SymRef(self, node: itir.SymRef, *, sym_map: dict[str, str]):
return im.ref(sym_map[node.id]) if node.id in sym_map else node


def is_equal(a: itir.Expr, b: itir.Expr):
"""
Return true if two expressions have provably equal values.

Be aware that this function might return false even though the two expression have the same
value.

>>> testee1 = im.lambda_("a")(im.plus("a", "b"))
>>> testee2 = im.lambda_("c")(im.plus("c", "b"))
>>> assert is_equal(testee1, testee2)

>>> testee1 = im.lambda_("a")(im.plus("a", "b"))
>>> testee2 = im.lambda_("c")(im.plus("c", "d"))
>>> assert not is_equal(testee1, testee2)
"""
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved
# TODO(tehrengruber): Extend this function cover more cases than just those with equal
# structure, e.g., by also canonicalization of the structure.
return a == b or (
CannonicalizeBoundSymbolNames.apply(a) == CannonicalizeBoundSymbolNames.apply(b)
)
Loading
Loading