Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber committed Dec 6, 2024
1 parent 5a892f3 commit 914a9e5
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 60 deletions.
121 changes: 61 additions & 60 deletions src/gt4py/next/iterator/transforms/collapse_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]:
method = getattr(self, f"transform_{transformation.name.lower()}")
result = method(node, **kwargs)
if result is not None:
assert result is not node
assert result is not node # transformation should have returned None, since nothing changed
itir_type_inference.reinfer(result)
return result
return None
Expand Down Expand Up @@ -361,69 +361,70 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Opt
def transform_propagate_to_if_on_tuples_cps(
self, node: ir.FunCall, **kwargs
) -> Optional[ir.Node]:
if not cpm.is_call_to(node, "if_"):
for i, arg in enumerate(node.args):
if cpm.is_call_to(arg, "if_"):
itir_type_inference.reinfer(arg)
if not any(isinstance(branch.type, ts.TupleType) for branch in arg.args[1:]):
continue
if cpm.is_call_to(node, "if_"):
return None

cond, true_branch, false_branch = arg.args
tuple_type: ts.TupleType = true_branch.type # type: ignore[assignment] # type ensured above
tuple_len = len(tuple_type.types)
itir_type_inference.reinfer(node)
assert node.type

# transform function into continuation-passing-style
f_type = ts.FunctionType(
pos_only_args=tuple_type.types,
pos_or_kw_args={},
kw_only_args={},
returns=node.type,
)
f_params = [
im.sym(self.uids.sequential_id(prefix="__ct_el_cps"), type_)
for type_ in tuple_type.types
]
f_args = [im.ref(param.id, param.type) for param in f_params]
f_body = _with_altered_arg(node, i, im.make_tuple(*f_args))
# simplify, e.g., inline trivial make_tuple args
new_f_body = self.fp_transform(f_body, **kwargs)
# if the function did not simplify there is nothing to gain. Skip
# transformation.
if new_f_body is f_body:
continue
# if the function is not trivial the transformation would still work, but
# inlining would result in a larger tree again and we didn't didn't gain
# anything compared to regular `propagate_to_if_on_tuples`. Not inling also
# works, but we don't want bound lambda functions in our tree (at least right
# now).
if not _is_trivial_or_tuple_thereof_expr(new_f_body):
continue
f = im.lambda_(*f_params)(new_f_body)

tuple_var = self.uids.sequential_id(prefix="__ct_tuple_cps")
f_var = self.uids.sequential_id(prefix="__ct_cont")
new_branches = []
for branch in arg.args[1:]:
new_branch = im.let(tuple_var, branch)(
im.call(im.ref(f_var, f_type))(
*(
im.tuple_get(i, im.ref(tuple_var, branch.type))
for i in range(tuple_len)
)
for i, arg in enumerate(node.args):
if cpm.is_call_to(arg, "if_"):
itir_type_inference.reinfer(arg)
if not any(isinstance(branch.type, ts.TupleType) for branch in arg.args[1:]):
continue

cond, true_branch, false_branch = arg.args
tuple_type: ts.TupleType = true_branch.type # type: ignore[assignment] # type ensured above
tuple_len = len(tuple_type.types)

# transform function into continuation-passing-style
itir_type_inference.reinfer(node)
assert node.type
f_type = ts.FunctionType(
pos_only_args=tuple_type.types,
pos_or_kw_args={},
kw_only_args={},
returns=node.type,
)
f_params = [
im.sym(self.uids.sequential_id(prefix="__ct_el_cps"), type_)
for type_ in tuple_type.types
]
f_args = [im.ref(param.id, param.type) for param in f_params]
f_body = _with_altered_arg(node, i, im.make_tuple(*f_args))
# simplify, e.g., inline trivial make_tuple args
new_f_body = self.fp_transform(f_body, **kwargs)
# if the function did not simplify there is nothing to gain. Skip
# transformation.
if new_f_body is f_body:
continue
# if the function is not trivial the transformation would still work, but
# inlining would result in a larger tree again and we didn't didn't gain
# anything compared to regular `propagate_to_if_on_tuples`. Not inling also
# works, but we don't want bound lambda functions in our tree (at least right
# now).
if not _is_trivial_or_tuple_thereof_expr(new_f_body):
continue
f = im.lambda_(*f_params)(new_f_body)

tuple_var = self.uids.sequential_id(prefix="__ct_tuple_cps")
f_var = self.uids.sequential_id(prefix="__ct_cont")
new_branches = []
for branch in arg.args[1:]:
new_branch = im.let(tuple_var, branch)(
im.call(im.ref(f_var, f_type))(
*(
im.tuple_get(i, im.ref(tuple_var, branch.type))
for i in range(tuple_len)
)
)
new_branches.append(self.fp_transform(new_branch, **kwargs))

new_node = im.let(f_var, f)(im.if_(cond, *new_branches))
new_node = inline_lambda(new_node, eligible_params=[True])
assert cpm.is_call_to(new_node, "if_")
new_node = im.if_(
cond, *(self.fp_transform(branch, **kwargs) for branch in new_node.args[1:])
)
return new_node
return None
new_branches.append(self.fp_transform(new_branch, **kwargs))

new_node = im.let(f_var, f)(im.if_(cond, *new_branches))
new_node = inline_lambda(new_node, eligible_params=[True])
assert cpm.is_call_to(new_node, "if_")
new_node = im.if_(
cond, *(self.fp_transform(branch, **kwargs) for branch in new_node.args[1:])
)
return new_node

def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
if cpm.is_let(node):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,23 @@ def test_if_make_tuple_reorder_cps():
assert actual == expected


def test_if_make_tuple_reorder_cps():
testee = im.let(
("t1", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4))),
("t2", im.if_(False, im.make_tuple(5, 6), im.make_tuple(7, 8)))
)(
im.make_tuple(im.tuple_get(1, "t"), im.tuple_get(0, "t"))
)
expected = im.if_(True, im.if_(False, im.make_tuple(2, 1), im.make_tuple(4, 3)))
actual = CollapseTuple.apply(
testee,
flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES,
allow_undeclared_symbols=True,
within_stencil=False,
)
assert actual == expected


def test_if_make_tuple_reorder_cps_nested():
testee = im.let("t", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4)))(
im.let("c", im.tuple_get(0, "t"))(
Expand Down

0 comments on commit 914a9e5

Please sign in to comment.