Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature[next]: Improve CollapseTuple pass #1350

Merged
merged 31 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
fe49d60
Improve CollapseTuple pass
tehrengruber Oct 17, 2023
cd7b6e3
Add comments
tehrengruber Oct 17, 2023
e49eae0
Bugfixes for scan
tehrengruber Oct 18, 2023
4e750b4
Introduce `_is_equal_value_heuristics` to avoid double visit
tehrengruber Oct 21, 2023
29560a3
Merge origin/main
tehrengruber Nov 27, 2023
e6c9f44
Fix tests
tehrengruber Nov 28, 2023
5a27524
Fix tests
tehrengruber Nov 28, 2023
367a05f
Merge remote-tracking branch 'origin/main' into improve_collapse_tuple
tehrengruber Nov 28, 2023
0437fce
Move iterator utils to dedicated module (in preparation for other PR …
tehrengruber Nov 28, 2023
36039a4
Fix format
tehrengruber Nov 28, 2023
ea42a16
Merge branch 'refactor_ir_utils' into improve_collapse_tuple
tehrengruber Nov 28, 2023
8772906
Fix tests
tehrengruber Nov 29, 2023
ace2dc0
Fix tests
tehrengruber Nov 29, 2023
2621f89
Merge branch 'refactor_ir_utils' into improve_collapse_tuple
tehrengruber Nov 29, 2023
8b7a6d7
Fix tests
tehrengruber Nov 29, 2023
b5fd847
Merge remote-tracking branch 'origin/main' into improve_collapse_tuple
tehrengruber Nov 29, 2023
6113312
Fix tests
tehrengruber Nov 29, 2023
6680075
Merge origin/main
tehrengruber Jan 4, 2024
3367c28
Small fix
tehrengruber Jan 4, 2024
c22cd35
Merge origin/main
tehrengruber Jan 25, 2024
7c15360
Merge remote-tracking branch 'origin/main' into improve_collapse_tuple
tehrengruber Feb 3, 2024
3a2a007
Cleanup
tehrengruber Feb 6, 2024
d915af5
Cleanup
tehrengruber Feb 6, 2024
fcdb8ae
Cleanup
tehrengruber Feb 6, 2024
8d4f93e
Merge remote-tracking branch 'origin/main' into improve_collapse_tuple
tehrengruber Feb 6, 2024
796f3a7
Revert debug changes to caching
tehrengruber Feb 7, 2024
bdc9221
Cleanup
tehrengruber Feb 7, 2024
9902d63
Address reviewer comments
tehrengruber Feb 7, 2024
650a934
Fix broken CI
tehrengruber Feb 8, 2024
00bbf2b
Merge remote-tracking branch 'origin/main' into improve_collapse_tuple
tehrengruber Feb 8, 2024
e9f6fb1
Address reviewer comments
tehrengruber Feb 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import TypeGuard

from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im


def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]:
Expand All @@ -24,3 +25,13 @@ def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]:
and isinstance(arg.fun.fun, itir.SymRef)
and arg.fun.fun.id == "lift"
)


def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]:
"""Match expression of the form `(λ(...) → ...)(...)`."""
return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda)


def is_if_call(node: itir.Expr) -> TypeGuard[itir.FunCall]:
"""Match expression of the form `if_(cond, true_branch, false_branch)`."""
return isinstance(node, itir.FunCall) and node.fun == im.ref("if_")
30 changes: 23 additions & 7 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from typing import Callable, Union
import typing
from typing import Callable, Iterable, Union

from gt4py._core import definitions as core_defs
from gt4py.next.iterator import ir as itir
Expand Down Expand Up @@ -242,16 +243,31 @@ class let:
--------
>>> str(let("a", "b")("a")) # doctest: +ELLIPSIS
'(λ(a) → a)(b)'
>>> str(let("a", 1,
... "b", 2
>>> str(let(("a", 1),
... ("b", 2)
... )(plus("a", "b")))
'(λ(a, b) → a + b)(1, 2)'
"""

def __init__(self, *vars_and_values):
assert len(vars_and_values) % 2 == 0
self.vars = vars_and_values[0::2]
self.init_forms = vars_and_values[1::2]
@typing.overload
def __init__(self, var: str | itir.Sym, init_form: itir.Expr): ...

@typing.overload
def __init__(self, *args: Iterable[tuple[str | itir.Sym, itir.Expr]]): ...

def __init__(self, *args):
if all(isinstance(arg, tuple) and len(arg) == 2 for arg in args):
assert isinstance(args, tuple)
assert all(isinstance(arg, tuple) and len(arg) == 2 for arg in args)
self.vars = [var for var, _ in args]
self.init_forms = [init_form for _, init_form in args]
elif len(args) == 2:
self.vars = [args[0]]
self.init_forms = [args[1]]
else:
raise TypeError(
"Invalid arguments: expected a variable name and an init form or a list thereof."
)

def __call__(self, form):
return call(lambda_(*self.vars)(form))(*self.init_forms)
Expand Down
70 changes: 70 additions & 0 deletions src/gt4py/next/iterator/ir_utils/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2023, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

import dataclasses
from collections import ChainMap

from gt4py import eve
from gt4py.eve import utils as eve_utils
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im


@dataclasses.dataclass(frozen=True)
class CannonicalizeBoundSymbolNames(eve.NodeTranslator):
"""
Given an iterator expression cannonicalize all bound symbol names.

If two such expression are in the same scope and equal so are their values.

>>> testee1 = im.lambda_("a")(im.plus("a", "b"))
>>> cannonicalized_testee1 = CannonicalizeBoundSymbolNames.apply(testee1)
>>> str(cannonicalized_testee1)
'λ(_csym_1) → _csym_1 + b'

>>> testee2 = im.lambda_("c")(im.plus("c", "b"))
>>> cannonicalized_testee2 = CannonicalizeBoundSymbolNames.apply(testee2)
>>> assert cannonicalized_testee1 == cannonicalized_testee2
"""

_uids: eve_utils.UIDGenerator = dataclasses.field(
init=False, repr=False, default_factory=lambda: eve_utils.UIDGenerator(prefix="_csym")
)

@classmethod
def apply(cls, node: itir.Expr):
return cls().visit(node, sym_map=ChainMap({}))

def visit_Lambda(self, node: itir.Lambda, *, sym_map: ChainMap):
sym_map = sym_map.new_child()
for param in node.params:
sym_map[str(param.id)] = self._uids.sequential_id()

return im.lambda_(*sym_map.values())(self.visit(node.expr, sym_map=sym_map))

def visit_SymRef(self, node: itir.SymRef, *, sym_map: dict[str, str]):
return im.ref(sym_map[node.id]) if node.id in sym_map else node


def is_provable_equal(a: itir.Expr, b: itir.Expr):
"""
Return true if, bot not only if, two expression (with equal scope) have the same value.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand the "bot not only if", even after bot->but

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a sufficient, but not a necessary condition. I think the formulation is correct as Return true if and only if certainly is, but I was already unhappy with it when I wrote it. I take any suggestion. I also renamed the function to is_equal, I vaguely remember we had a discussion in person about it.


>>> testee1 = im.lambda_("a")(im.plus("a", "b"))
>>> testee2 = im.lambda_("c")(im.plus("c", "b"))
>>> assert is_provable_equal(testee1, testee2)
"""
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved
return a == b or (
CannonicalizeBoundSymbolNames.apply(a) == CannonicalizeBoundSymbolNames.apply(b)
)
Loading
Loading