-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature[next]: Non-tree-size-increasing collapse tuple on ifs #1762
base: main
Are you sure you want to change the base?
Conversation
def visit_OffsetLiteral( | ||
self, node: itir.OffsetLiteral, **kwargs | ||
) -> it_ts.OffsetLiteralType | ts.DeferredType: | ||
if self.reinfer: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if self.reinfer: | |
# `self.dimensions` not available in reinference mode. Skip since we don't care anyway. | |
if self.reinfer: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need more documentation of the new trafo or additionally a more detailed explanation on how it works.
if cpm.is_call_to(node, "make_tuple"): | ||
return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args) | ||
if cpm.is_call_to(node, "tuple_get"): | ||
return _is_trivial_or_tuple_thereof_expr(node.args[1]) | ||
if isinstance(node, (ir.SymRef, ir.Literal)): | ||
return True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if cpm.is_call_to(node, "make_tuple"): | |
return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args) | |
if cpm.is_call_to(node, "tuple_get"): | |
return _is_trivial_or_tuple_thereof_expr(node.args[1]) | |
if isinstance(node, (ir.SymRef, ir.Literal)): | |
return True | |
if isinstance(node, (ir.SymRef, ir.Literal)): | |
return True | |
if cpm.is_call_to(node, "make_tuple"): | |
return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args) | |
if cpm.is_call_to(node, "tuple_get"): | |
return _is_trivial_or_tuple_thereof_expr(node.args[1]) |
let's move the definition of trivial to the top
@@ -47,6 +48,32 @@ def _is_trivial_make_tuple_call(node: ir.Expr): | |||
return True | |||
|
|||
|
|||
def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool: | |||
""" | |||
Return `true` if the expr is a trivial expression or tuple thereof. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Return `true` if the expr is a trivial expression or tuple thereof. | |
Return `true` if the expr is a trivial expression (`SymRef` or `Literal`) or tuple thereof. |
@@ -185,6 +229,8 @@ def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: | |||
method = getattr(self, f"transform_{transformation.name.lower()}") | |||
result = method(node, **kwargs) | |||
if result is not None: | |||
assert result is not node |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this assert useful?
@@ -126,7 +137,10 @@ def apply_common_transforms( | |||
# only run the unconditional version here instead of in the loop above. | |||
if unconditionally_collapse_tuples: | |||
ir = CollapseTuple.apply( | |||
ir, ignore_tuple_size=True, offset_provider_type=offset_provider_type | |||
ir, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why don't we need to exclude one of [PROPAGATE_TO_IF_ON_TUPLES, PROPAGATE_TO_IF_ON_TUPLES_CPS]
self, node: ir.FunCall, **kwargs | ||
) -> Optional[ir.Node]: | ||
if not cpm.is_call_to(node, "if_"): | ||
for i, arg in enumerate(node.args): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here we iterate over any functions args or do we know more?
@@ -312,6 +358,73 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Opt | |||
return im.if_(cond, new_true_branch, new_false_branch) | |||
return None | |||
|
|||
def transform_propagate_to_if_on_tuples_cps( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you break this function into smaller pieces, I am completely lost...
Contrary to the regular inference, this method does not descend into already typed sub-nodes | ||
and can be used as a lightweight way to restore type information during a pass. | ||
|
||
Note that this function is stateful, which is usually desired, and more performant. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where is the statefulness?
# we found a node that is typed, do not descend into children | ||
if self.reinfer and isinstance(node, itir.Node) and node.type: | ||
if isinstance(node.type, ts.FunctionType): | ||
return _type_synthesizer_from_function_type(node.type) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does this do?
@@ -239,3 +240,48 @@ def test_tuple_get_on_untyped_ref(): | |||
|
|||
actual = CollapseTuple.apply(testee, allow_undeclared_symbols=True, within_stencil=False) | |||
assert actual == testee | |||
|
|||
|
|||
def test_if_make_tuple_reorder_cps(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a test with ifs in more than one arg?
Removing the tuple expressions across
if_
calls on ITIR has been a pain point in the past. While thePROPAGATE_TO_IF_ON_TUPLES
option of theCollapseTuplePass
works very reliably, the resulting increase in the tree size has been prohibitive. With the refactoring to GTIR this problem became much less pronounced, as we could restrict the propagation to field-level, i.e., outside of stencils, but the tree still grew exponentially in the number of references to boolean arguments used insideif_
conditions. This PR adds an additional optionPROPAGATE_TO_IF_ON_TUPLES_CPS
to theCollapseTuplePass
, which is similar to the existingPROPAGATE_TO_IF_ON_TUPLES
, but propagates in the opposite direction, i.e. into the tree. This allows removal of tuple expressions acrossif_
calls without increasing the size of the tree. This is particularly important forif
statements in the frontend, where outwards propagation can have devastating effects on the tree size, without any gained optimization potential. For exampleis problematic, since
PROPAGATE_TO_IF_ON_TUPLES
would propagate, and hence duplicate,complex_lambda
three times, while we only want to get rid of the tuple expressions inside of theif_
s. Note that this transformation is not mutually exclusive toPROPAGATE_TO_IF_ON_TUPLES
.