Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature[next]: Non-tree-size-increasing collapse tuple on ifs #1762

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
153 changes: 135 additions & 18 deletions src/gt4py/next/iterator/transforms/collapse_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@
from gt4py.next.type_system import type_info, type_specifications as ts


def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr):
def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr | str):
"""Given a itir.FunCall return a new call with one of its argument replaced."""
return ir.FunCall(
fun=node.fun, args=[arg if i != arg_idx else new_arg for i, arg in enumerate(node.args)]
fun=node.fun,
args=[arg if i != arg_idx else im.ensure_expr(new_arg) for i, arg in enumerate(node.args)],
)


Expand All @@ -47,6 +48,32 @@ def _is_trivial_make_tuple_call(node: ir.Expr):
return True


def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool:
"""
Return `true` if the expr is a trivial expression or tuple thereof.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Return `true` if the expr is a trivial expression or tuple thereof.
Return `true` if the expr is a trivial expression (`SymRef` or `Literal`) or tuple thereof.


>>> _is_trivial_or_tuple_thereof_expr(im.make_tuple("a", "b"))
True
>>> _is_trivial_or_tuple_thereof_expr(im.tuple_get(1, "a"))
True
>>> _is_trivial_or_tuple_thereof_expr(
... im.let("t", im.make_tuple("a", "b"))(im.tuple_get(1, "t"))
... )
True
"""
if cpm.is_call_to(node, "make_tuple"):
return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args)
if cpm.is_call_to(node, "tuple_get"):
return _is_trivial_or_tuple_thereof_expr(node.args[1])
if isinstance(node, (ir.SymRef, ir.Literal)):
return True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if cpm.is_call_to(node, "make_tuple"):
return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args)
if cpm.is_call_to(node, "tuple_get"):
return _is_trivial_or_tuple_thereof_expr(node.args[1])
if isinstance(node, (ir.SymRef, ir.Literal)):
return True
if isinstance(node, (ir.SymRef, ir.Literal)):
return True
if cpm.is_call_to(node, "make_tuple"):
return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args)
if cpm.is_call_to(node, "tuple_get"):
return _is_trivial_or_tuple_thereof_expr(node.args[1])

let's move the definition of trivial to the top

if cpm.is_let(node):
return _is_trivial_or_tuple_thereof_expr(node.fun.expr) and all( # type: ignore[attr-defined] # ensured by is_let
_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args
)
return False


# TODO(tehrengruber): Conceptually the structure of this pass makes sense: Visit depth first,
# transform each node until no transformations apply anymore, whenever a node is to be transformed
# go through all available transformation and apply them. However the final result here still
Expand Down Expand Up @@ -76,28 +103,42 @@ class Flag(enum.Flag):
#: `let(tup, {trivial_expr1, trivial_expr2})(foo(tup))`
#: -> `foo({trivial_expr1, trivial_expr2})`
INLINE_TRIVIAL_MAKE_TUPLE = enum.auto()
#: Similar as `PROPAGATE_TO_IF_ON_TUPLES`, but propagates in the opposite direction, i.e.
#: into the tree, allowing removal of tuple expressions across `if_` calls without
#: increasing the size of the tree. This is particularly important for `if` statements
#: in the frontend, where outwards propagation can have devastating effects on the tree
#: size, without any gained optimization potential. For example
#: ```
#: complex_lambda(if cond1
#: if cond2
#: {...}
#: else:
#: {...}
#: else
#: {...})
#: ```
#: is problematic, since `PROPAGATE_TO_IF_ON_TUPLES` would propagate and hence duplicate
#: `complex_lambda` three times, while we only want to get rid of the tuple expressions
#: inside of the `if_`s.
#: Note that this transformation is not mutually exclusive to `PROPAGATE_TO_IF_ON_TUPLES`.
PROPAGATE_TO_IF_ON_TUPLES_CPS = enum.auto()
#: `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]`
PROPAGATE_TO_IF_ON_TUPLES = enum.auto()
#: `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))`
PROPAGATE_NESTED_LET = enum.auto()
#: `let(a, 1)(a)` -> `1`
#: `let(a, 1)(a)` -> `1` or `let(a, b)(f(a))` -> `f(a)`
INLINE_TRIVIAL_LET = enum.auto()

@classmethod
def all(self) -> CollapseTuple.Flag:
return functools.reduce(operator.or_, self.__members__.values())

uids: eve_utils.UIDGenerator
ignore_tuple_size: bool
flags: Flag = Flag.all() # noqa: RUF009 [function-call-in-dataclass-default-argument]

PRESERVED_ANNEX_ATTRS = ("type",)

# we use one UID generator per instance such that the generated ids are
# stable across multiple runs (required for caching to properly work)
_letify_make_tuple_uids: eve_utils.UIDGenerator = dataclasses.field(
init=False, repr=False, default_factory=lambda: eve_utils.UIDGenerator(prefix="_tuple_el")
)

@classmethod
def apply(
cls,
Expand All @@ -111,6 +152,7 @@ def apply(
flags: Optional[Flag] = None,
# allow sym references without a symbol declaration, mostly for testing
allow_undeclared_symbols: bool = False,
uids: Optional[eve_utils.UIDGenerator] = None,
) -> ir.Node:
"""
Simplifies `make_tuple`, `tuple_get` calls.
Expand All @@ -127,6 +169,7 @@ def apply(
"""
flags = flags or cls.flags
offset_provider_type = offset_provider_type or {}
uids = uids or eve_utils.UIDGenerator()

if isinstance(node, (ir.Program, ir.FencilDefinition)):
within_stencil = False
Expand All @@ -145,6 +188,7 @@ def apply(
new_node = cls(
ignore_tuple_size=ignore_tuple_size,
flags=flags,
uids=uids,
).visit(node, within_stencil=within_stencil)

# inline to remove left-overs from LETIFY_MAKE_TUPLE_ELEMENTS. this is important
Expand Down Expand Up @@ -185,6 +229,8 @@ def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]:
method = getattr(self, f"transform_{transformation.name.lower()}")
result = method(node, **kwargs)
if result is not None:
assert result is not node
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this assert useful?

itir_type_inference.reinfer(result)
return result
return None

Expand Down Expand Up @@ -263,13 +309,13 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall, **kwargs) -> Op
if node.fun == ir.SymRef(id="make_tuple"):
# `make_tuple(expr1, expr1)`
# -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))`
bound_vars: dict[str, ir.Expr] = {}
bound_vars: dict[ir.Sym, ir.Expr] = {}
new_args: list[ir.Expr] = []
for arg in node.args:
if cpm.is_call_to(node, "make_tuple") and not _is_trivial_make_tuple_call(node):
el_name = self._letify_make_tuple_uids.sequential_id()
new_args.append(im.ref(el_name))
bound_vars[el_name] = arg
el_name = self.uids.sequential_id(prefix="__ct_el")
new_args.append(im.ref(el_name, arg.type))
bound_vars[im.sym(el_name, arg.type)] = arg
else:
new_args.append(arg)

Expand Down Expand Up @@ -312,6 +358,73 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Opt
return im.if_(cond, new_true_branch, new_false_branch)
return None

def transform_propagate_to_if_on_tuples_cps(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you break this function into smaller pieces, I am completely lost...

self, node: ir.FunCall, **kwargs
) -> Optional[ir.Node]:
if not cpm.is_call_to(node, "if_"):
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved
for i, arg in enumerate(node.args):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we iterate over any functions args or do we know more?

if cpm.is_call_to(arg, "if_"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure how it looks, but maybe do the same here: if not cpm.is_call_to(arg, "if_"): continue

itir_type_inference.reinfer(arg)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the missing type from a previous transform in the fp iteration?

if not any(isinstance(branch.type, ts.TupleType) for branch in arg.args[1:]):
continue

cond, true_branch, false_branch = arg.args
tuple_type: ts.TupleType = true_branch.type # type: ignore[assignment] # type ensured above
tuple_len = len(tuple_type.types)
itir_type_inference.reinfer(node)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

under which conditions is this needed?

assert node.type

# transform function into continuation-passing-style
f_type = ts.FunctionType(
pos_only_args=tuple_type.types,
pos_or_kw_args={},
kw_only_args={},
returns=node.type,
)
f_params = [
im.sym(self.uids.sequential_id(prefix="__ct_el_cps"), type_)
for type_ in tuple_type.types
]
f_args = [im.ref(param.id, param.type) for param in f_params]
f_body = _with_altered_arg(node, i, im.make_tuple(*f_args))
# simplify, e.g., inline trivial make_tuple args
new_f_body = self.fp_transform(f_body, **kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this recursion, we handle all the other args after the current one? I am confused...

# if the function did not simplify there is nothing to gain. Skip
# transformation.
if new_f_body is f_body:
continue
# if the function is not trivial the transformation would still work, but
# inlining would result in a larger tree again and we didn't didn't gain
# anything compared to regular `propagate_to_if_on_tuples`. Not inling also
# works, but we don't want bound lambda functions in our tree (at least right
# now).
if not _is_trivial_or_tuple_thereof_expr(new_f_body):
continue
f = im.lambda_(*f_params)(new_f_body)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

basically here it's decided that we actually do something.


tuple_var = self.uids.sequential_id(prefix="__ct_tuple_cps")
f_var = self.uids.sequential_id(prefix="__ct_cont")
new_branches = []
for branch in arg.args[1:]:
new_branch = im.let(tuple_var, branch)(
im.call(im.ref(f_var, f_type))(
*(
im.tuple_get(i, im.ref(tuple_var, branch.type))
for i in range(tuple_len)
)
)
)
new_branches.append(self.fp_transform(new_branch, **kwargs))

new_node = im.let(f_var, f)(im.if_(cond, *new_branches))
new_node = inline_lambda(new_node, eligible_params=[True])
assert cpm.is_call_to(new_node, "if_")
new_node = im.if_(
cond, *(self.fp_transform(branch, **kwargs) for branch in new_node.args[1:])
)
return new_node
return None

def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
if cpm.is_let(node):
# `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))`
Expand Down Expand Up @@ -339,9 +452,13 @@ def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional
return None

def transform_inline_trivial_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
if cpm.is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let
# `let(a, 1)(a)` -> `1`
for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let
if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let
return arg
if cpm.is_let(node):
if isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let
# `let(a, 1)(a)` -> `1`
for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let
if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let
return arg
if any(trivial_args := [isinstance(arg, (ir.SymRef, ir.Literal)) for arg in node.args]):
return inline_lambda(node, eligible_params=trivial_args)

return None
24 changes: 20 additions & 4 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def apply_common_transforms(

tmp_uids = eve_utils.UIDGenerator(prefix="__tmp")
mergeasfop_uids = eve_utils.UIDGenerator()
collapse_tuple_uids = eve_utils.UIDGenerator()

ir = MergeLet().visit(ir)
ir = inline_fundefs.InlineFundefs().visit(ir)
Expand All @@ -80,7 +81,12 @@ def apply_common_transforms(
# Inline. The domain inference can not handle "user" functions, e.g. `let f = λ(...) → ... in f(...)`
ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True)
# required in order to get rid of expressions without a domain (e.g. when a tuple element is never accessed)
ir = CollapseTuple.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program
ir = CollapseTuple.apply(
ir,
flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES,
uids=collapse_tuple_uids,
offset_provider_type=offset_provider_type,
) # type: ignore[assignment] # always an itir.Program
ir = infer_domain.infer_program(
ir, # type: ignore[arg-type] # always an itir.Program
offset_provider=offset_provider,
Expand All @@ -94,7 +100,12 @@ def apply_common_transforms(
inlined = ConstantFolding.apply(inlined) # type: ignore[assignment] # always an itir.Program
# This pass is required to be in the loop such that when an `if_` call with tuple arguments
# is constant-folded the surrounding tuple_get calls can be removed.
inlined = CollapseTuple.apply(inlined, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program
inlined = CollapseTuple.apply(
inlined,
flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES,
uids=collapse_tuple_uids,
offset_provider_type=offset_provider_type,
) # type: ignore[assignment] # always an itir.Program
inlined = InlineScalar.apply(inlined, offset_provider_type=offset_provider_type)

# This pass is required to run after CollapseTuple as otherwise we can not inline
Expand Down Expand Up @@ -126,7 +137,10 @@ def apply_common_transforms(
# only run the unconditional version here instead of in the loop above.
if unconditionally_collapse_tuples:
ir = CollapseTuple.apply(
ir, ignore_tuple_size=True, offset_provider_type=offset_provider_type
ir,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why don't we need to exclude one of [PROPAGATE_TO_IF_ON_TUPLES, PROPAGATE_TO_IF_ON_TUPLES_CPS]

ignore_tuple_size=True,
uids=collapse_tuple_uids,
offset_provider_type=offset_provider_type,
) # type: ignore[assignment] # always an itir.Program

ir = NormalizeShifts().visit(ir)
Expand Down Expand Up @@ -164,7 +178,9 @@ def apply_fieldview_transforms(
ir = inline_fundefs.prune_unreferenced_fundefs(ir)
ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True)
ir = CollapseTuple.apply(
ir, offset_provider_type=common.offset_provider_to_type(offset_provider)
ir,
flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES,
offset_provider_type=common.offset_provider_to_type(offset_provider),
) # type: ignore[assignment] # type is still `itir.Program`
ir = infer_domain.infer_program(ir, offset_provider=offset_provider)
return ir
Loading