Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature[next]: Improve CollapseTuple pass #1350

Merged
merged 31 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
fe49d60
Improve CollapseTuple pass
tehrengruber Oct 17, 2023
cd7b6e3
Add comments
tehrengruber Oct 17, 2023
e49eae0
Bugfixes for scan
tehrengruber Oct 18, 2023
4e750b4
Introduce `_is_equal_value_heuristics` to avoid double visit
tehrengruber Oct 21, 2023
29560a3
Merge origin/main
tehrengruber Nov 27, 2023
e6c9f44
Fix tests
tehrengruber Nov 28, 2023
5a27524
Fix tests
tehrengruber Nov 28, 2023
367a05f
Merge remote-tracking branch 'origin/main' into improve_collapse_tuple
tehrengruber Nov 28, 2023
0437fce
Move iterator utils to dedicated module (in preparation for other PR …
tehrengruber Nov 28, 2023
36039a4
Fix format
tehrengruber Nov 28, 2023
ea42a16
Merge branch 'refactor_ir_utils' into improve_collapse_tuple
tehrengruber Nov 28, 2023
8772906
Fix tests
tehrengruber Nov 29, 2023
ace2dc0
Fix tests
tehrengruber Nov 29, 2023
2621f89
Merge branch 'refactor_ir_utils' into improve_collapse_tuple
tehrengruber Nov 29, 2023
8b7a6d7
Fix tests
tehrengruber Nov 29, 2023
b5fd847
Merge remote-tracking branch 'origin/main' into improve_collapse_tuple
tehrengruber Nov 29, 2023
6113312
Fix tests
tehrengruber Nov 29, 2023
6680075
Merge origin/main
tehrengruber Jan 4, 2024
3367c28
Small fix
tehrengruber Jan 4, 2024
c22cd35
Merge origin/main
tehrengruber Jan 25, 2024
7c15360
Merge remote-tracking branch 'origin/main' into improve_collapse_tuple
tehrengruber Feb 3, 2024
3a2a007
Cleanup
tehrengruber Feb 6, 2024
d915af5
Cleanup
tehrengruber Feb 6, 2024
fcdb8ae
Cleanup
tehrengruber Feb 6, 2024
8d4f93e
Merge remote-tracking branch 'origin/main' into improve_collapse_tuple
tehrengruber Feb 6, 2024
796f3a7
Revert debug changes to caching
tehrengruber Feb 7, 2024
bdc9221
Cleanup
tehrengruber Feb 7, 2024
9902d63
Address reviewer comments
tehrengruber Feb 7, 2024
650a934
Fix broken CI
tehrengruber Feb 8, 2024
00bbf2b
Merge remote-tracking branch 'origin/main' into improve_collapse_tuple
tehrengruber Feb 8, 2024
e9f6fb1
Address reviewer comments
tehrengruber Feb 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 209 additions & 20 deletions src/gt4py/next/iterator/transforms/collapse_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,18 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later
import enum
from collections import ChainMap
from typing import Callable
import hashlib

from dataclasses import dataclass
import dataclasses

from gt4py import eve
from gt4py.eve.utils import UIDGenerator
from gt4py.next import type_inference
from gt4py.next.iterator import ir, type_inference as it_type_inference
from gt4py.next.iterator import ir, type_inference as it_type_inference, ir_makers as im
from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda, InlineLambdas


def _get_tuple_size(type_: type_inference.Type) -> int:
Expand All @@ -25,8 +31,72 @@ def _get_tuple_size(type_: type_inference.Type) -> int:
)
return len(type_.dtype)

def _is_let(node: ir.Node) -> bool:
return isinstance(node, ir.FunCall) and isinstance(node.fun, ir.Lambda)

@dataclass(frozen=True)
def _is_if_call(node: ir.Expr):
return isinstance(node, ir.FunCall) and node.fun == im.ref("if_")
havogt marked this conversation as resolved.
Show resolved Hide resolved

def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr):
return ir.FunCall(
fun=node.fun,
args=[arg if i != arg_idx else new_arg for i, arg in enumerate(node.args)]
)

def _is_trivial_make_tuple_call(node: ir.Expr):
if not (isinstance(node, ir.FunCall) and node.fun == im.ref("make_tuple")):
return False
if not all(isinstance(arg, (ir.SymRef, ir.Literal)) or _is_trivial_make_tuple_call(arg) for arg in node.args):
return False
return True

def nlet(bindings: list[tuple[ir.Sym, str], ir.Expr]):
return im.let(*[el for tup in bindings for el in tup])

def _short_hash(val: str) -> str:
return hashlib.sha1(val.encode('UTF-8')).hexdigest()[0:6]

@dataclasses.dataclass(frozen=True)
class CannonicalizeBoundSymbolNames(eve.NodeTranslator):
havogt marked this conversation as resolved.
Show resolved Hide resolved
"""
Given an iterator expression cannonicalize all bound symbol names.

If two such expression are in the same scope and equal so are their values.

>>> testee1 = im.lambda_("a")(im.plus("a", "b"))
>>> cannonicalized_testee1 = CannonicalizeBoundSymbolNames.apply(testee1)
>>> str(cannonicalized_testee1)
'λ(_csym_1) → _csym_1 + b'

>>> testee2 = im.lambda_("c")(im.plus("c", "b"))
>>> cannonicalized_testee2 = CannonicalizeBoundSymbolNames.apply(testee2)
>>> assert cannonicalized_testee1 == cannonicalized_testee2
"""
_uids: UIDGenerator = dataclasses.field(
init=False, repr=False, default_factory=lambda: UIDGenerator(prefix="_csym")
)

@classmethod
def apply(cls, node: ir.Expr):
return cls().visit(node, sym_map=ChainMap({}))

def visit_Lambda(self, node: ir.Lambda, *, sym_map: ChainMap):
sym_map = sym_map.new_child()
for param in node.params:
sym_map[str(param.id)] = self._uids.sequential_id()

return im.lambda_(*sym_map.values())(self.visit(node.expr, sym_map=sym_map))

def visit_SymRef(self, node: ir.SymRef, *, sym_map: dict[str, str]):
return im.ref(sym_map[node.id]) if node.id in sym_map else node

def _is_equal_value_heuristics(a: ir.Expr, b: ir.Expr):
"""
Return true if, bot not only if, two expression (with equal scope) have the same value.
"""
return a == b or (CannonicalizeBoundSymbolNames.apply(a) == CannonicalizeBoundSymbolNames.apply(b))
havogt marked this conversation as resolved.
Show resolved Hide resolved

@dataclasses.dataclass(frozen=True)
class CollapseTuple(eve.NodeTranslator):
"""
Simplifies `make_tuple`, `tuple_get` calls.
Expand All @@ -35,40 +105,77 @@ class CollapseTuple(eve.NodeTranslator):
- `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i`
"""

class Flag(enum.IntEnum):
havogt marked this conversation as resolved.
Show resolved Hide resolved
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved
#: `make_tuple(tuple_get(0, t), tuple_get(1, t), ..., tuple_get(N-1,t))` -> `t`
COLLAPSE_MAKE_TUPLE_TUPLE_GET = 1
#: `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i`
COLLAPSE_TUPLE_GET_MAKE_TUPLE = 2
#: `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))`
PROPAGATE_TUPLE_GET = 4
#: `{1, 2}` -> `(λ(_tuple_el_1, _tuple_el_2) → {_tuple_el_1, _tuple_el_2})(1, 2)`
LETIFY_MAKE_TUPLE_ELEMENTS = 8
#: TODO
INLINE_TRIVIAL_MAKE_TUPLE = 16
#: TODO
PROPAGATE_TO_IF_ON_TUPLES = 32
#: TODO
PROPAGATE_NESTED_LET=64
#: TODO
INLINE_TRIVIAL_LET=128

ignore_tuple_size: bool
collapse_make_tuple_tuple_get: bool
collapse_tuple_get_make_tuple: bool
flags: int = (Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET
| Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE
| Flag.PROPAGATE_TUPLE_GET
| Flag.LETIFY_MAKE_TUPLE_ELEMENTS
| Flag.INLINE_TRIVIAL_MAKE_TUPLE
| Flag.PROPAGATE_TO_IF_ON_TUPLES
| Flag.PROPAGATE_NESTED_LET
| Flag.INLINE_TRIVIAL_LET)

PRESERVED_ANNEX_ATTRS = ("type",)

_node_types: dict[int, type_inference.Type]
# we use one UID generator per instance such that the generated ids are
# stable across multiple runs (required for caching to properly work)
_letify_make_tuple_uids: UIDGenerator = dataclasses.field(
init=False, repr=False, default_factory=lambda: UIDGenerator(prefix="_tuple_el")
)

@classmethod
def apply(
cls,
node: ir.Node,
*,
ignore_tuple_size: bool = False,
# the following options are mostly for allowing separate testing of the modes
collapse_make_tuple_tuple_get: bool = True,
collapse_tuple_get_make_tuple: bool = True,
# manually passing flags is mostly for allowing separate testing of the modes
flags = None
) -> ir.Node:
"""
Simplifies `make_tuple`, `tuple_get` calls.

If `ignore_tuple_size`, apply the transformation even if length of the inner tuple
is greater than the length of the outer tuple.
"""
node_types = it_type_inference.infer_all(node)
it_type_inference.infer_all(node, save_to_annex=True)

return cls(
ignore_tuple_size,
collapse_make_tuple_tuple_get,
collapse_tuple_get_make_tuple,
node_types,
new_node = cls(
ignore_tuple_size=ignore_tuple_size,
flags=flags or cls.flags
).visit(node)

# inline to remove left-overs from LETIFY_MAKE_TUPLE_ELEMENTS. this is important
# as otherwise two equal expressions containing a tuple will not be equal anymore
# and the CSE pass can not remove them.
# TODO: test case for `scan(lambda carry: {1, 2})` (see solve_nonhydro_stencil_52_like_z_q_tup)
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved
new_node = InlineLambdas.apply(new_node, opcount_preserving=True, force_inline_lambda_args=False)

return new_node

def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node:
node = self.generic_visit(node, **kwargs)

if (
self.collapse_make_tuple_tuple_get
self.flags & self.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET
and node.fun == ir.SymRef(id="make_tuple")
and all(
isinstance(arg, ir.FunCall) and arg.fun == ir.SymRef(id="tuple_get")
Expand All @@ -82,16 +189,17 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node:
for i, v in enumerate(node.args):
assert isinstance(v, ir.FunCall)
assert isinstance(v.args[0], ir.Literal)
if not (int(v.args[0].value) == i and v.args[1] == first_expr):
if not (int(v.args[0].value) == i and _is_equal_value_heuristics(v.args[1], first_expr)):
# tuple argument differs, just continue with the rest of the tree
return self.generic_visit(node)

if self.ignore_tuple_size or _get_tuple_size(self._node_types[id(first_expr)]) == len(
if self.ignore_tuple_size or _get_tuple_size(first_expr.annex.type) == len(
node.args
):
return first_expr

if (
self.collapse_tuple_get_make_tuple
self.flags & self.Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE
and node.fun == ir.SymRef(id="tuple_get")
and isinstance(node.args[1], ir.FunCall)
and node.args[1].fun == ir.SymRef(id="make_tuple")
Expand All @@ -105,4 +213,85 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node:
make_tuple_call.args
), f"Index {idx} is out of bounds for tuple of size {len(make_tuple_call.args)}"
return node.args[1].args[idx]
return self.generic_visit(node)

if (
self.flags & self.Flag.PROPAGATE_TUPLE_GET
and node.fun == ir.SymRef(id="tuple_get")
and isinstance(node.args[0], ir.Literal) # TODO: extend to general symbols as long as the tail call in the let does not capture
):
# `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))`
if _is_let(node.args[1]):
idx, let_expr = node.args
return self.visit(
im.call(im.lambda_(*let_expr.fun.params)(im.tuple_get(idx, let_expr.fun.expr)))(*let_expr.args)
)
elif isinstance(node.args[1], ir.FunCall) and node.args[1].fun == im.ref("if_"):
idx = node.args[0]
cond, true_branch, false_branch = node.args[1].args
return self.visit(
im.if_(cond, im.tuple_get(idx, true_branch), im.tuple_get(idx, false_branch))
) # todo: check if visit needed

if (
self.flags & self.Flag.LETIFY_MAKE_TUPLE_ELEMENTS
and node.fun == ir.SymRef(id="make_tuple")
):
# `make_tuple(expr1, expr1)`
# -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))`
bound_vars: dict[str, ir.Expr] = {}
new_args: list[ir.Expr] = []
for i, arg in enumerate(node.args):
if isinstance(node, ir.FunCall) and node.fun == im.ref(
"make_tuple") and not _is_trivial_make_tuple_call(node):
el_name = self._letify_make_tuple_uids.sequential_id()
new_args.append(im.ref(el_name))
bound_vars[el_name] = arg
else:
new_args.append(arg)

if bound_vars:
return self.visit(im.let(*(el for item in bound_vars.items() for el in item))(
im.call(node.fun)(*new_args)))

if self.flags & self.Flag.INLINE_TRIVIAL_MAKE_TUPLE and _is_let(node):
# `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))`
# -> `foo(make_tuple(trivial_expr1, trivial_expr2))`
eligible_params = [_is_trivial_make_tuple_call(arg) for arg in node.args]
if any(eligible_params):
return self.visit(inline_lambda(node, eligible_params=eligible_params))

if self.flags & self.Flag.PROPAGATE_TO_IF_ON_TUPLES and not node.fun == im.ref("if_"):
# TODO(tehrengruber): This significantly increases the size of the tree. Revisit.
# TODO(tehrengruber): Only inline if type of branch value is a tuple.
# `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]`
for i, arg in enumerate(node.args):
if _is_if_call(arg):
cond, true_branch, false_branch = arg.args
new_true_branch = self.visit(_with_altered_arg(node, i, true_branch), **kwargs)
new_false_branch = self.visit(_with_altered_arg(node, i, false_branch), **kwargs)
return im.if_(cond, new_true_branch, new_false_branch)

if self.flags & self.Flag.PROPAGATE_NESTED_LET and _is_let(node):
# `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))`
outer_vars = {}
inner_vars = {}
original_inner_expr = node.fun.expr
for arg_sym, arg in zip(node.fun.params, node.args):
assert arg_sym not in inner_vars # TODO: fix collisions
if _is_let(arg):
for sym, val in zip(arg.fun.params, arg.args):
assert sym not in outer_vars # TODO: fix collisions
outer_vars[sym] = val
inner_vars[arg_sym] = arg.fun.expr
else:
inner_vars[arg_sym] = arg
if outer_vars:
node = self.visit(nlet(tuple(outer_vars.items()))(nlet(tuple(inner_vars.items()))(original_inner_expr)))

if self.flags & self.Flag.INLINE_TRIVIAL_LET and _is_let(node) and isinstance(node.fun.expr, ir.SymRef):
# `let(a, 1)(a)` -> `1`
for arg_sym, arg in zip(node.fun.params, node.args):
if node.fun.expr == im.ref(arg_sym.id):
return arg

return node
12 changes: 12 additions & 0 deletions src/gt4py/next/iterator/transforms/inline_lambdas.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def inline_lambda( # noqa: C901 # see todo above
opcount_preserving=False,
force_inline_lift_args=False,
force_inline_trivial_lift_args=False,
force_inline_lambda_args=True,
eligible_params: Optional[list[bool]] = None,
):
assert isinstance(node.fun, ir.Lambda)
Expand Down Expand Up @@ -59,6 +60,12 @@ def inline_lambda( # noqa: C901 # see todo above
if is_applied_lift(arg) and len(arg.args) == 0:
eligible_params[i] = True

# TODO(tehrengruber): make configurable
if force_inline_lambda_args:
for i, arg in enumerate(node.args):
if isinstance(arg, ir.Lambda):
eligible_params[i] = True

if node.fun.params and not any(eligible_params):
return node

Expand Down Expand Up @@ -124,13 +131,16 @@ class InlineLambdas(NodeTranslator):

force_inline_trivial_lift_args: bool

force_inline_lambda_args: bool

@classmethod
def apply(
cls,
node: ir.Node,
opcount_preserving=False,
force_inline_lift_args=False,
force_inline_trivial_lift_args=False,
force_inline_lambda_args=True,
):
"""
Inline lambda calls by substituting every argument by its value.
Expand All @@ -156,6 +166,7 @@ def apply(
opcount_preserving=opcount_preserving,
force_inline_lift_args=force_inline_lift_args,
force_inline_trivial_lift_args=force_inline_trivial_lift_args,
force_inline_lambda_args=force_inline_lambda_args,
).visit(node)

def visit_FunCall(self, node: ir.FunCall):
Expand All @@ -166,6 +177,7 @@ def visit_FunCall(self, node: ir.FunCall):
opcount_preserving=self.opcount_preserving,
force_inline_lift_args=self.force_inline_lift_args,
force_inline_trivial_lift_args=self.force_inline_trivial_lift_args,
force_inline_lambda_args=self.force_inline_lambda_args,
)

return node
1 change: 1 addition & 0 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def apply_common_transforms(
# This pass is required to be in the loop such that when an `if_` call with tuple arguments
# is constant-folded the surrounding tuple_get calls can be removed.
inlined = CollapseTuple.apply(inlined)
inlined = PropagateDeref.apply(inlined) # todo: document

if inlined == ir:
break
Expand Down
6 changes: 5 additions & 1 deletion src/gt4py/next/iterator/transforms/propagate_deref.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from gt4py.eve import NodeTranslator
from gt4py.eve.pattern_matching import ObjectPattern as P
from gt4py.next.iterator import ir
from gt4py.next.iterator import ir, ir_makers as im


# TODO(tehrengruber): This pass can be generalized to all builtins, e.g.
Expand Down Expand Up @@ -56,4 +56,8 @@ def visit_FunCall(self, node: ir.FunCall):
),
args=lambda_args,
)
elif node.fun == im.ref("deref") and isinstance(node.args[0], ir.FunCall) and node.args[0].fun == im.ref("if_"):
havogt marked this conversation as resolved.
Show resolved Hide resolved
cond, true_branch, false_branch = node.args[0].args
return im.if_(cond, im.deref(true_branch), im.deref(false_branch))

return self.generic_visit(node)
2 changes: 1 addition & 1 deletion tests/next_tests/exclusion_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
GTFN_SKIP_TEST_LIST = [
(REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE),
(USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE),
#(USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE),
(USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE),
(USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE),
Expand Down
Loading
Loading