Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature[next]: Non-tree-size-increasing collapse tuple on ifs #1762

Open
wants to merge 18 commits into
base: main
Choose a base branch
from

Conversation

tehrengruber
Copy link
Contributor

@tehrengruber tehrengruber commented Dec 1, 2024

Removing the tuple expressions across if_ calls on ITIR has been a pain point in the past. While the PROPAGATE_TO_IF_ON_TUPLES option of the CollapseTuplePass works very reliably, the resulting increase in the tree size has been prohibitive. With the refactoring to GTIR this problem became much less pronounced, as we could restrict the propagation to field-level, i.e., outside of stencils, but the tree still grew exponentially in the number of references to boolean arguments used inside if_ conditions. This PR adds an additional option PROPAGATE_TO_IF_ON_TUPLES_CPS to the CollapseTuplePass, which is similar to the existing PROPAGATE_TO_IF_ON_TUPLES, but propagates in the opposite direction, i.e. into the tree. This allows removal of tuple expressions across if_ calls without increasing the size of the tree. This is particularly important for if statements in the frontend, where outwards propagation can have devastating effects on the tree size, without any gained optimization potential. For example

complex_lambda(if cond1
  if cond2
    {...}
  else:
    {...}
else
  {...})

is problematic, since PROPAGATE_TO_IF_ON_TUPLES would propagate, and hence duplicate, complex_lambda three times, while we only want to get rid of the tuple expressions inside of the if_s. Note that this transformation is not mutually exclusive to PROPAGATE_TO_IF_ON_TUPLES.

def visit_OffsetLiteral(
self, node: itir.OffsetLiteral, **kwargs
) -> it_ts.OffsetLiteralType | ts.DeferredType:
if self.reinfer:
Copy link
Contributor Author

@tehrengruber tehrengruber Dec 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if self.reinfer:
# `self.dimensions` not available in reinference mode. Skip since we don't care anyway.
if self.reinfer:

Copy link
Contributor

@havogt havogt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need more documentation of the new trafo or additionally a more detailed explanation on how it works.

Comment on lines 64 to 69
if cpm.is_call_to(node, "make_tuple"):
return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args)
if cpm.is_call_to(node, "tuple_get"):
return _is_trivial_or_tuple_thereof_expr(node.args[1])
if isinstance(node, (ir.SymRef, ir.Literal)):
return True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if cpm.is_call_to(node, "make_tuple"):
return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args)
if cpm.is_call_to(node, "tuple_get"):
return _is_trivial_or_tuple_thereof_expr(node.args[1])
if isinstance(node, (ir.SymRef, ir.Literal)):
return True
if isinstance(node, (ir.SymRef, ir.Literal)):
return True
if cpm.is_call_to(node, "make_tuple"):
return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args)
if cpm.is_call_to(node, "tuple_get"):
return _is_trivial_or_tuple_thereof_expr(node.args[1])

let's move the definition of trivial to the top

@@ -47,6 +48,32 @@ def _is_trivial_make_tuple_call(node: ir.Expr):
return True


def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool:
"""
Return `true` if the expr is a trivial expression or tuple thereof.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Return `true` if the expr is a trivial expression or tuple thereof.
Return `true` if the expr is a trivial expression (`SymRef` or `Literal`) or tuple thereof.

@@ -185,6 +229,8 @@ def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]:
method = getattr(self, f"transform_{transformation.name.lower()}")
result = method(node, **kwargs)
if result is not None:
assert result is not node
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this assert useful?

@@ -126,7 +137,10 @@ def apply_common_transforms(
# only run the unconditional version here instead of in the loop above.
if unconditionally_collapse_tuples:
ir = CollapseTuple.apply(
ir, ignore_tuple_size=True, offset_provider_type=offset_provider_type
ir,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why don't we need to exclude one of [PROPAGATE_TO_IF_ON_TUPLES, PROPAGATE_TO_IF_ON_TUPLES_CPS]

src/gt4py/next/iterator/transforms/collapse_tuple.py Outdated Show resolved Hide resolved
self, node: ir.FunCall, **kwargs
) -> Optional[ir.Node]:
if not cpm.is_call_to(node, "if_"):
for i, arg in enumerate(node.args):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we iterate over any functions args or do we know more?

@@ -312,6 +358,73 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Opt
return im.if_(cond, new_true_branch, new_false_branch)
return None

def transform_propagate_to_if_on_tuples_cps(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you break this function into smaller pieces, I am completely lost...

Contrary to the regular inference, this method does not descend into already typed sub-nodes
and can be used as a lightweight way to restore type information during a pass.

Note that this function is stateful, which is usually desired, and more performant.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is the statefulness?

# we found a node that is typed, do not descend into children
if self.reinfer and isinstance(node, itir.Node) and node.type:
if isinstance(node.type, ts.FunctionType):
return _type_synthesizer_from_function_type(node.type)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this do?

@@ -239,3 +240,48 @@ def test_tuple_get_on_untyped_ref():

actual = CollapseTuple.apply(testee, allow_undeclared_symbols=True, within_stencil=False)
assert actual == testee


def test_if_make_tuple_reorder_cps():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test with ifs in more than one arg?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants