Skip to content

Commit

Permalink
Add test and fix nested transformation on nested ifs
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber committed Dec 10, 2024
1 parent 914a9e5 commit fc46edf
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
10 changes: 9 additions & 1 deletion src/gt4py/next/iterator/transforms/collapse_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool:
return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args)
if cpm.is_call_to(node, "tuple_get"):
return _is_trivial_or_tuple_thereof_expr(node.args[1])
if cpm.is_call_to(node, "if_"):
return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args[1:])
if isinstance(node, (ir.SymRef, ir.Literal)):
return True
if cpm.is_let(node):
Expand Down Expand Up @@ -229,7 +231,9 @@ def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]:
method = getattr(self, f"transform_{transformation.name.lower()}")
result = method(node, **kwargs)
if result is not None:
assert result is not node # transformation should have returned None, since nothing changed
assert (
result is not node
) # transformation should have returned None, since nothing changed
itir_type_inference.reinfer(result)
return result
return None
Expand Down Expand Up @@ -400,6 +404,8 @@ def transform_propagate_to_if_on_tuples_cps(
# anything compared to regular `propagate_to_if_on_tuples`. Not inling also
# works, but we don't want bound lambda functions in our tree (at least right
# now).
# TODO(tehrengruber): `if_` of trivial expression is also considered fine. This
# will duplicate the condition and unnecessarily increase the size of the tree.
if not _is_trivial_or_tuple_thereof_expr(new_f_body):
continue
f = im.lambda_(*f_params)(new_f_body)
Expand All @@ -426,6 +432,8 @@ def transform_propagate_to_if_on_tuples_cps(
)
return new_node

return None

def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
if cpm.is_let(node):
# `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,14 +256,23 @@ def test_if_make_tuple_reorder_cps():
assert actual == expected


def test_if_make_tuple_reorder_cps():
def test_nested_if_make_tuple_reorder_cps():
testee = im.let(
("t1", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4))),
("t2", im.if_(False, im.make_tuple(5, 6), im.make_tuple(7, 8)))
("t2", im.if_(False, im.make_tuple(5, 6), im.make_tuple(7, 8))),
)(
im.make_tuple(im.tuple_get(1, "t"), im.tuple_get(0, "t"))
im.make_tuple(
im.tuple_get(1, "t1"),
im.tuple_get(0, "t1"),
im.tuple_get(1, "t2"),
im.tuple_get(0, "t2"),
)
)
expected = im.if_(
True,
im.if_(False, im.make_tuple(2, 1, 6, 5), im.make_tuple(2, 1, 8, 7)),
im.if_(False, im.make_tuple(4, 3, 6, 5), im.make_tuple(4, 3, 8, 7)),
)
expected = im.if_(True, im.if_(False, im.make_tuple(2, 1), im.make_tuple(4, 3)))
actual = CollapseTuple.apply(
testee,
flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES,
Expand Down

0 comments on commit fc46edf

Please sign in to comment.