diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index f9bcccbb26..01ccbc8ab6 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -65,6 +65,8 @@ def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool: return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args) if cpm.is_call_to(node, "tuple_get"): return _is_trivial_or_tuple_thereof_expr(node.args[1]) + if cpm.is_call_to(node, "if_"): + return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args[1:]) if isinstance(node, (ir.SymRef, ir.Literal)): return True if cpm.is_let(node): @@ -229,7 +231,9 @@ def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: method = getattr(self, f"transform_{transformation.name.lower()}") result = method(node, **kwargs) if result is not None: - assert result is not node # transformation should have returned None, since nothing changed + assert ( + result is not node + ) # transformation should have returned None, since nothing changed itir_type_inference.reinfer(result) return result return None @@ -400,6 +404,8 @@ def transform_propagate_to_if_on_tuples_cps( # anything compared to regular `propagate_to_if_on_tuples`. Not inling also # works, but we don't want bound lambda functions in our tree (at least right # now). + # TODO(tehrengruber): `if_` of trivial expression is also considered fine. This + # will duplicate the condition and unnecessarily increase the size of the tree. if not _is_trivial_or_tuple_thereof_expr(new_f_body): continue f = im.lambda_(*f_params)(new_f_body) @@ -426,6 +432,8 @@ def transform_propagate_to_if_on_tuples_cps( ) return new_node + return None + def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if cpm.is_let(node): # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index f216b48856..5e2c07ef0a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -256,14 +256,23 @@ def test_if_make_tuple_reorder_cps(): assert actual == expected -def test_if_make_tuple_reorder_cps(): +def test_nested_if_make_tuple_reorder_cps(): testee = im.let( ("t1", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4))), - ("t2", im.if_(False, im.make_tuple(5, 6), im.make_tuple(7, 8))) + ("t2", im.if_(False, im.make_tuple(5, 6), im.make_tuple(7, 8))), )( - im.make_tuple(im.tuple_get(1, "t"), im.tuple_get(0, "t")) + im.make_tuple( + im.tuple_get(1, "t1"), + im.tuple_get(0, "t1"), + im.tuple_get(1, "t2"), + im.tuple_get(0, "t2"), + ) + ) + expected = im.if_( + True, + im.if_(False, im.make_tuple(2, 1, 6, 5), im.make_tuple(2, 1, 8, 7)), + im.if_(False, im.make_tuple(4, 3, 6, 5), im.make_tuple(4, 3, 8, 7)), ) - expected = im.if_(True, im.if_(False, im.make_tuple(2, 1), im.make_tuple(4, 3))) actual = CollapseTuple.apply( testee, flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES,