-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature[next]: Non-tree-size-increasing collapse tuple on ifs #1762
base: main
Are you sure you want to change the base?
Changes from all commits
34d6040
48abc08
3790944
a8a63bf
2c44ffc
42b5817
9da19a2
bcd9e48
0a212bd
70562fe
9cee650
43f5741
7b37f1c
a0341a6
5a892f3
914a9e5
fc46edf
4e12195
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,10 +28,11 @@ | |
from gt4py.next.type_system import type_info, type_specifications as ts | ||
|
||
|
||
def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr): | ||
def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr | str): | ||
"""Given a itir.FunCall return a new call with one of its argument replaced.""" | ||
return ir.FunCall( | ||
fun=node.fun, args=[arg if i != arg_idx else new_arg for i, arg in enumerate(node.args)] | ||
fun=node.fun, | ||
args=[arg if i != arg_idx else im.ensure_expr(new_arg) for i, arg in enumerate(node.args)], | ||
) | ||
|
||
|
||
|
@@ -47,6 +48,34 @@ def _is_trivial_make_tuple_call(node: ir.Expr): | |
return True | ||
|
||
|
||
def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool: | ||
""" | ||
Return `true` if the expr is a trivial expression or tuple thereof. | ||
|
||
>>> _is_trivial_or_tuple_thereof_expr(im.make_tuple("a", "b")) | ||
True | ||
>>> _is_trivial_or_tuple_thereof_expr(im.tuple_get(1, "a")) | ||
True | ||
>>> _is_trivial_or_tuple_thereof_expr( | ||
... im.let("t", im.make_tuple("a", "b"))(im.tuple_get(1, "t")) | ||
... ) | ||
True | ||
""" | ||
if cpm.is_call_to(node, "make_tuple"): | ||
return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args) | ||
if cpm.is_call_to(node, "tuple_get"): | ||
return _is_trivial_or_tuple_thereof_expr(node.args[1]) | ||
if cpm.is_call_to(node, "if_"): | ||
return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args[1:]) | ||
if isinstance(node, (ir.SymRef, ir.Literal)): | ||
return True | ||
if cpm.is_let(node): | ||
return _is_trivial_or_tuple_thereof_expr(node.fun.expr) and all( # type: ignore[attr-defined] # ensured by is_let | ||
_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args | ||
) | ||
return False | ||
|
||
|
||
# TODO(tehrengruber): Conceptually the structure of this pass makes sense: Visit depth first, | ||
# transform each node until no transformations apply anymore, whenever a node is to be transformed | ||
# go through all available transformation and apply them. However the final result here still | ||
|
@@ -76,28 +105,42 @@ class Flag(enum.Flag): | |
#: `let(tup, {trivial_expr1, trivial_expr2})(foo(tup))` | ||
#: -> `foo({trivial_expr1, trivial_expr2})` | ||
INLINE_TRIVIAL_MAKE_TUPLE = enum.auto() | ||
#: Similar as `PROPAGATE_TO_IF_ON_TUPLES`, but propagates in the opposite direction, i.e. | ||
#: into the tree, allowing removal of tuple expressions across `if_` calls without | ||
#: increasing the size of the tree. This is particularly important for `if` statements | ||
#: in the frontend, where outwards propagation can have devastating effects on the tree | ||
#: size, without any gained optimization potential. For example | ||
#: ``` | ||
#: complex_lambda(if cond1 | ||
#: if cond2 | ||
#: {...} | ||
#: else: | ||
#: {...} | ||
#: else | ||
#: {...}) | ||
#: ``` | ||
#: is problematic, since `PROPAGATE_TO_IF_ON_TUPLES` would propagate and hence duplicate | ||
#: `complex_lambda` three times, while we only want to get rid of the tuple expressions | ||
#: inside of the `if_`s. | ||
#: Note that this transformation is not mutually exclusive to `PROPAGATE_TO_IF_ON_TUPLES`. | ||
PROPAGATE_TO_IF_ON_TUPLES_CPS = enum.auto() | ||
#: `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` | ||
PROPAGATE_TO_IF_ON_TUPLES = enum.auto() | ||
#: `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` | ||
PROPAGATE_NESTED_LET = enum.auto() | ||
#: `let(a, 1)(a)` -> `1` | ||
#: `let(a, 1)(a)` -> `1` or `let(a, b)(f(a))` -> `f(a)` | ||
INLINE_TRIVIAL_LET = enum.auto() | ||
|
||
@classmethod | ||
def all(self) -> CollapseTuple.Flag: | ||
return functools.reduce(operator.or_, self.__members__.values()) | ||
|
||
uids: eve_utils.UIDGenerator | ||
ignore_tuple_size: bool | ||
flags: Flag = Flag.all() # noqa: RUF009 [function-call-in-dataclass-default-argument] | ||
|
||
PRESERVED_ANNEX_ATTRS = ("type",) | ||
|
||
# we use one UID generator per instance such that the generated ids are | ||
# stable across multiple runs (required for caching to properly work) | ||
_letify_make_tuple_uids: eve_utils.UIDGenerator = dataclasses.field( | ||
init=False, repr=False, default_factory=lambda: eve_utils.UIDGenerator(prefix="_tuple_el") | ||
) | ||
|
||
@classmethod | ||
def apply( | ||
cls, | ||
|
@@ -111,6 +154,7 @@ def apply( | |
flags: Optional[Flag] = None, | ||
# allow sym references without a symbol declaration, mostly for testing | ||
allow_undeclared_symbols: bool = False, | ||
uids: Optional[eve_utils.UIDGenerator] = None, | ||
) -> ir.Node: | ||
""" | ||
Simplifies `make_tuple`, `tuple_get` calls. | ||
|
@@ -127,6 +171,7 @@ def apply( | |
""" | ||
flags = flags or cls.flags | ||
offset_provider_type = offset_provider_type or {} | ||
uids = uids or eve_utils.UIDGenerator() | ||
|
||
if isinstance(node, ir.Program): | ||
within_stencil = False | ||
|
@@ -145,6 +190,7 @@ def apply( | |
new_node = cls( | ||
ignore_tuple_size=ignore_tuple_size, | ||
flags=flags, | ||
uids=uids, | ||
).visit(node, within_stencil=within_stencil) | ||
|
||
# inline to remove left-overs from LETIFY_MAKE_TUPLE_ELEMENTS. this is important | ||
|
@@ -185,6 +231,10 @@ def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: | |
method = getattr(self, f"transform_{transformation.name.lower()}") | ||
result = method(node, **kwargs) | ||
if result is not None: | ||
assert ( | ||
result is not node | ||
) # transformation should have returned None, since nothing changed | ||
itir_type_inference.reinfer(result) | ||
return result | ||
return None | ||
|
||
|
@@ -263,13 +313,13 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall, **kwargs) -> Op | |
if node.fun == ir.SymRef(id="make_tuple"): | ||
# `make_tuple(expr1, expr1)` | ||
# -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))` | ||
bound_vars: dict[str, ir.Expr] = {} | ||
bound_vars: dict[ir.Sym, ir.Expr] = {} | ||
new_args: list[ir.Expr] = [] | ||
for arg in node.args: | ||
if cpm.is_call_to(node, "make_tuple") and not _is_trivial_make_tuple_call(node): | ||
el_name = self._letify_make_tuple_uids.sequential_id() | ||
new_args.append(im.ref(el_name)) | ||
bound_vars[el_name] = arg | ||
el_name = self.uids.sequential_id(prefix="__ct_el") | ||
new_args.append(im.ref(el_name, arg.type)) | ||
bound_vars[im.sym(el_name, arg.type)] = arg | ||
else: | ||
new_args.append(arg) | ||
|
||
|
@@ -312,6 +362,78 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Opt | |
return im.if_(cond, new_true_branch, new_false_branch) | ||
return None | ||
|
||
def transform_propagate_to_if_on_tuples_cps( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you break this function into smaller pieces, I am completely lost... |
||
self, node: ir.FunCall, **kwargs | ||
) -> Optional[ir.Node]: | ||
if cpm.is_call_to(node, "if_"): | ||
return None | ||
|
||
for i, arg in enumerate(node.args): | ||
if cpm.is_call_to(arg, "if_"): | ||
itir_type_inference.reinfer(arg) | ||
if not any(isinstance(branch.type, ts.TupleType) for branch in arg.args[1:]): | ||
continue | ||
|
||
cond, true_branch, false_branch = arg.args | ||
tuple_type: ts.TupleType = true_branch.type # type: ignore[assignment] # type ensured above | ||
tuple_len = len(tuple_type.types) | ||
|
||
# transform function into continuation-passing-style | ||
itir_type_inference.reinfer(node) | ||
assert node.type | ||
f_type = ts.FunctionType( | ||
pos_only_args=tuple_type.types, | ||
pos_or_kw_args={}, | ||
kw_only_args={}, | ||
returns=node.type, | ||
) | ||
f_params = [ | ||
im.sym(self.uids.sequential_id(prefix="__ct_el_cps"), type_) | ||
for type_ in tuple_type.types | ||
] | ||
f_args = [im.ref(param.id, param.type) for param in f_params] | ||
f_body = _with_altered_arg(node, i, im.make_tuple(*f_args)) | ||
# simplify, e.g., inline trivial make_tuple args | ||
new_f_body = self.fp_transform(f_body, **kwargs) | ||
# if the function did not simplify there is nothing to gain. Skip | ||
# transformation. | ||
if new_f_body is f_body: | ||
continue | ||
# if the function is not trivial the transformation would still work, but | ||
# inlining would result in a larger tree again and we didn't didn't gain | ||
# anything compared to regular `propagate_to_if_on_tuples`. Not inling also | ||
# works, but we don't want bound lambda functions in our tree (at least right | ||
# now). | ||
# TODO(tehrengruber): `if_` of trivial expression is also considered fine. This | ||
# will duplicate the condition and unnecessarily increase the size of the tree. | ||
if not _is_trivial_or_tuple_thereof_expr(new_f_body): | ||
continue | ||
f = im.lambda_(*f_params)(new_f_body) | ||
|
||
tuple_var = self.uids.sequential_id(prefix="__ct_tuple_cps") | ||
f_var = self.uids.sequential_id(prefix="__ct_cont") | ||
new_branches = [] | ||
for branch in arg.args[1:]: | ||
new_branch = im.let(tuple_var, branch)( | ||
im.call(im.ref(f_var, f_type))( | ||
*( | ||
im.tuple_get(i, im.ref(tuple_var, branch.type)) | ||
for i in range(tuple_len) | ||
) | ||
) | ||
) | ||
new_branches.append(self.fp_transform(new_branch, **kwargs)) | ||
|
||
new_node = im.let(f_var, f)(im.if_(cond, *new_branches)) | ||
new_node = inline_lambda(new_node, eligible_params=[True]) | ||
assert cpm.is_call_to(new_node, "if_") | ||
new_node = im.if_( | ||
cond, *(self.fp_transform(branch, **kwargs) for branch in new_node.args[1:]) | ||
) | ||
return new_node | ||
|
||
return None | ||
|
||
def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: | ||
if cpm.is_let(node): | ||
# `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` | ||
|
@@ -339,9 +461,13 @@ def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional | |
return None | ||
|
||
def transform_inline_trivial_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: | ||
if cpm.is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let | ||
# `let(a, 1)(a)` -> `1` | ||
for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let | ||
if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let | ||
return arg | ||
if cpm.is_let(node): | ||
if isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let | ||
# `let(a, 1)(a)` -> `1` | ||
for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let | ||
if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let | ||
return arg | ||
if any(trivial_args := [isinstance(arg, (ir.SymRef, ir.Literal)) for arg in node.args]): | ||
return inline_lambda(node, eligible_params=trivial_args) | ||
|
||
return None |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -62,6 +62,7 @@ def apply_common_transforms( | |
|
||
tmp_uids = eve_utils.UIDGenerator(prefix="__tmp") | ||
mergeasfop_uids = eve_utils.UIDGenerator() | ||
collapse_tuple_uids = eve_utils.UIDGenerator() | ||
|
||
ir = MergeLet().visit(ir) | ||
ir = inline_fundefs.InlineFundefs().visit(ir) | ||
|
@@ -73,7 +74,12 @@ def apply_common_transforms( | |
# Inline. The domain inference can not handle "user" functions, e.g. `let f = λ(...) → ... in f(...)` | ||
ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) | ||
# required in order to get rid of expressions without a domain (e.g. when a tuple element is never accessed) | ||
ir = CollapseTuple.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program | ||
ir = CollapseTuple.apply( | ||
ir, | ||
flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, | ||
uids=collapse_tuple_uids, | ||
offset_provider_type=offset_provider_type, | ||
) # type: ignore[assignment] # always an itir.Program | ||
ir = inline_dynamic_shifts.InlineDynamicShifts.apply( | ||
ir | ||
) # domain inference does not support dynamic offsets yet | ||
|
@@ -90,7 +96,12 @@ def apply_common_transforms( | |
inlined = ConstantFolding.apply(inlined) # type: ignore[assignment] # always an itir.Program | ||
# This pass is required to be in the loop such that when an `if_` call with tuple arguments | ||
# is constant-folded the surrounding tuple_get calls can be removed. | ||
inlined = CollapseTuple.apply(inlined, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program | ||
inlined = CollapseTuple.apply( | ||
inlined, | ||
flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, | ||
uids=collapse_tuple_uids, | ||
offset_provider_type=offset_provider_type, | ||
) # type: ignore[assignment] # always an itir.Program | ||
inlined = InlineScalar.apply(inlined, offset_provider_type=offset_provider_type) | ||
|
||
# This pass is required to run after CollapseTuple as otherwise we can not inline | ||
|
@@ -122,7 +133,10 @@ def apply_common_transforms( | |
# only run the unconditional version here instead of in the loop above. | ||
if unconditionally_collapse_tuples: | ||
ir = CollapseTuple.apply( | ||
ir, ignore_tuple_size=True, offset_provider_type=offset_provider_type | ||
ir, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why don't we need to exclude one of |
||
ignore_tuple_size=True, | ||
uids=collapse_tuple_uids, | ||
offset_provider_type=offset_provider_type, | ||
) # type: ignore[assignment] # always an itir.Program | ||
|
||
ir = NormalizeShifts().visit(ir) | ||
|
@@ -160,7 +174,9 @@ def apply_fieldview_transforms( | |
ir = inline_fundefs.prune_unreferenced_fundefs(ir) | ||
ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) | ||
ir = CollapseTuple.apply( | ||
ir, offset_provider_type=common.offset_provider_to_type(offset_provider) | ||
ir, | ||
flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, | ||
offset_provider_type=common.offset_provider_to_type(offset_provider), | ||
) # type: ignore[assignment] # type is still `itir.Program` | ||
ir = inline_dynamic_shifts.InlineDynamicShifts.apply( | ||
ir | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.