diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index b64886f729..ea7aad890c 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -28,10 +28,11 @@ from gt4py.next.type_system import type_info, type_specifications as ts -def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr): +def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr | str): """Given a itir.FunCall return a new call with one of its argument replaced.""" return ir.FunCall( - fun=node.fun, args=[arg if i != arg_idx else new_arg for i, arg in enumerate(node.args)] + fun=node.fun, + args=[arg if i != arg_idx else im.ensure_expr(new_arg) for i, arg in enumerate(node.args)], ) @@ -47,6 +48,34 @@ def _is_trivial_make_tuple_call(node: ir.Expr): return True +def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool: + """ + Return `true` if the expr is a trivial expression or tuple thereof. + + >>> _is_trivial_or_tuple_thereof_expr(im.make_tuple("a", "b")) + True + >>> _is_trivial_or_tuple_thereof_expr(im.tuple_get(1, "a")) + True + >>> _is_trivial_or_tuple_thereof_expr( + ... im.let("t", im.make_tuple("a", "b"))(im.tuple_get(1, "t")) + ... ) + True + """ + if cpm.is_call_to(node, "make_tuple"): + return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args) + if cpm.is_call_to(node, "tuple_get"): + return _is_trivial_or_tuple_thereof_expr(node.args[1]) + if cpm.is_call_to(node, "if_"): + return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args[1:]) + if isinstance(node, (ir.SymRef, ir.Literal)): + return True + if cpm.is_let(node): + return _is_trivial_or_tuple_thereof_expr(node.fun.expr) and all( # type: ignore[attr-defined] # ensured by is_let + _is_trivial_or_tuple_thereof_expr(arg) for arg in node.args + ) + return False + + # TODO(tehrengruber): Conceptually the structure of this pass makes sense: Visit depth first, # transform each node until no transformations apply anymore, whenever a node is to be transformed # go through all available transformation and apply them. However the final result here still @@ -76,28 +105,42 @@ class Flag(enum.Flag): #: `let(tup, {trivial_expr1, trivial_expr2})(foo(tup))` #: -> `foo({trivial_expr1, trivial_expr2})` INLINE_TRIVIAL_MAKE_TUPLE = enum.auto() + #: Similar as `PROPAGATE_TO_IF_ON_TUPLES`, but propagates in the opposite direction, i.e. + #: into the tree, allowing removal of tuple expressions across `if_` calls without + #: increasing the size of the tree. This is particularly important for `if` statements + #: in the frontend, where outwards propagation can have devastating effects on the tree + #: size, without any gained optimization potential. For example + #: ``` + #: complex_lambda(if cond1 + #: if cond2 + #: {...} + #: else: + #: {...} + #: else + #: {...}) + #: ``` + #: is problematic, since `PROPAGATE_TO_IF_ON_TUPLES` would propagate and hence duplicate + #: `complex_lambda` three times, while we only want to get rid of the tuple expressions + #: inside of the `if_`s. + #: Note that this transformation is not mutually exclusive to `PROPAGATE_TO_IF_ON_TUPLES`. + PROPAGATE_TO_IF_ON_TUPLES_CPS = enum.auto() #: `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` PROPAGATE_TO_IF_ON_TUPLES = enum.auto() #: `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` PROPAGATE_NESTED_LET = enum.auto() - #: `let(a, 1)(a)` -> `1` + #: `let(a, 1)(a)` -> `1` or `let(a, b)(f(a))` -> `f(a)` INLINE_TRIVIAL_LET = enum.auto() @classmethod def all(self) -> CollapseTuple.Flag: return functools.reduce(operator.or_, self.__members__.values()) + uids: eve_utils.UIDGenerator ignore_tuple_size: bool flags: Flag = Flag.all() # noqa: RUF009 [function-call-in-dataclass-default-argument] PRESERVED_ANNEX_ATTRS = ("type",) - # we use one UID generator per instance such that the generated ids are - # stable across multiple runs (required for caching to properly work) - _letify_make_tuple_uids: eve_utils.UIDGenerator = dataclasses.field( - init=False, repr=False, default_factory=lambda: eve_utils.UIDGenerator(prefix="_tuple_el") - ) - @classmethod def apply( cls, @@ -111,6 +154,7 @@ def apply( flags: Optional[Flag] = None, # allow sym references without a symbol declaration, mostly for testing allow_undeclared_symbols: bool = False, + uids: Optional[eve_utils.UIDGenerator] = None, ) -> ir.Node: """ Simplifies `make_tuple`, `tuple_get` calls. @@ -127,6 +171,7 @@ def apply( """ flags = flags or cls.flags offset_provider_type = offset_provider_type or {} + uids = uids or eve_utils.UIDGenerator() if isinstance(node, ir.Program): within_stencil = False @@ -145,6 +190,7 @@ def apply( new_node = cls( ignore_tuple_size=ignore_tuple_size, flags=flags, + uids=uids, ).visit(node, within_stencil=within_stencil) # inline to remove left-overs from LETIFY_MAKE_TUPLE_ELEMENTS. this is important @@ -185,6 +231,10 @@ def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: method = getattr(self, f"transform_{transformation.name.lower()}") result = method(node, **kwargs) if result is not None: + assert ( + result is not node + ) # transformation should have returned None, since nothing changed + itir_type_inference.reinfer(result) return result return None @@ -263,13 +313,13 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall, **kwargs) -> Op if node.fun == ir.SymRef(id="make_tuple"): # `make_tuple(expr1, expr1)` # -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))` - bound_vars: dict[str, ir.Expr] = {} + bound_vars: dict[ir.Sym, ir.Expr] = {} new_args: list[ir.Expr] = [] for arg in node.args: if cpm.is_call_to(node, "make_tuple") and not _is_trivial_make_tuple_call(node): - el_name = self._letify_make_tuple_uids.sequential_id() - new_args.append(im.ref(el_name)) - bound_vars[el_name] = arg + el_name = self.uids.sequential_id(prefix="__ct_el") + new_args.append(im.ref(el_name, arg.type)) + bound_vars[im.sym(el_name, arg.type)] = arg else: new_args.append(arg) @@ -312,6 +362,78 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Opt return im.if_(cond, new_true_branch, new_false_branch) return None + def transform_propagate_to_if_on_tuples_cps( + self, node: ir.FunCall, **kwargs + ) -> Optional[ir.Node]: + if cpm.is_call_to(node, "if_"): + return None + + for i, arg in enumerate(node.args): + if cpm.is_call_to(arg, "if_"): + itir_type_inference.reinfer(arg) + if not any(isinstance(branch.type, ts.TupleType) for branch in arg.args[1:]): + continue + + cond, true_branch, false_branch = arg.args + tuple_type: ts.TupleType = true_branch.type # type: ignore[assignment] # type ensured above + tuple_len = len(tuple_type.types) + + # transform function into continuation-passing-style + itir_type_inference.reinfer(node) + assert node.type + f_type = ts.FunctionType( + pos_only_args=tuple_type.types, + pos_or_kw_args={}, + kw_only_args={}, + returns=node.type, + ) + f_params = [ + im.sym(self.uids.sequential_id(prefix="__ct_el_cps"), type_) + for type_ in tuple_type.types + ] + f_args = [im.ref(param.id, param.type) for param in f_params] + f_body = _with_altered_arg(node, i, im.make_tuple(*f_args)) + # simplify, e.g., inline trivial make_tuple args + new_f_body = self.fp_transform(f_body, **kwargs) + # if the function did not simplify there is nothing to gain. Skip + # transformation. + if new_f_body is f_body: + continue + # if the function is not trivial the transformation would still work, but + # inlining would result in a larger tree again and we didn't didn't gain + # anything compared to regular `propagate_to_if_on_tuples`. Not inling also + # works, but we don't want bound lambda functions in our tree (at least right + # now). + # TODO(tehrengruber): `if_` of trivial expression is also considered fine. This + # will duplicate the condition and unnecessarily increase the size of the tree. + if not _is_trivial_or_tuple_thereof_expr(new_f_body): + continue + f = im.lambda_(*f_params)(new_f_body) + + tuple_var = self.uids.sequential_id(prefix="__ct_tuple_cps") + f_var = self.uids.sequential_id(prefix="__ct_cont") + new_branches = [] + for branch in arg.args[1:]: + new_branch = im.let(tuple_var, branch)( + im.call(im.ref(f_var, f_type))( + *( + im.tuple_get(i, im.ref(tuple_var, branch.type)) + for i in range(tuple_len) + ) + ) + ) + new_branches.append(self.fp_transform(new_branch, **kwargs)) + + new_node = im.let(f_var, f)(im.if_(cond, *new_branches)) + new_node = inline_lambda(new_node, eligible_params=[True]) + assert cpm.is_call_to(new_node, "if_") + new_node = im.if_( + cond, *(self.fp_transform(branch, **kwargs) for branch in new_node.args[1:]) + ) + return new_node + + return None + def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if cpm.is_let(node): # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` @@ -339,9 +461,13 @@ def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional return None def transform_inline_trivial_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - if cpm.is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let - # `let(a, 1)(a)` -> `1` - for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let - if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let - return arg + if cpm.is_let(node): + if isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let + # `let(a, 1)(a)` -> `1` + for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let + if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let + return arg + if any(trivial_args := [isinstance(arg, (ir.SymRef, ir.Literal)) for arg in node.args]): + return inline_lambda(node, eligible_params=trivial_args) + return None diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index d967c8fbb8..f3cb0cc468 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -62,6 +62,7 @@ def apply_common_transforms( tmp_uids = eve_utils.UIDGenerator(prefix="__tmp") mergeasfop_uids = eve_utils.UIDGenerator() + collapse_tuple_uids = eve_utils.UIDGenerator() ir = MergeLet().visit(ir) ir = inline_fundefs.InlineFundefs().visit(ir) @@ -73,7 +74,12 @@ def apply_common_transforms( # Inline. The domain inference can not handle "user" functions, e.g. `let f = λ(...) → ... in f(...)` ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) # required in order to get rid of expressions without a domain (e.g. when a tuple element is never accessed) - ir = CollapseTuple.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program + ir = CollapseTuple.apply( + ir, + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + uids=collapse_tuple_uids, + offset_provider_type=offset_provider_type, + ) # type: ignore[assignment] # always an itir.Program ir = inline_dynamic_shifts.InlineDynamicShifts.apply( ir ) # domain inference does not support dynamic offsets yet @@ -90,7 +96,12 @@ def apply_common_transforms( inlined = ConstantFolding.apply(inlined) # type: ignore[assignment] # always an itir.Program # This pass is required to be in the loop such that when an `if_` call with tuple arguments # is constant-folded the surrounding tuple_get calls can be removed. - inlined = CollapseTuple.apply(inlined, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program + inlined = CollapseTuple.apply( + inlined, + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + uids=collapse_tuple_uids, + offset_provider_type=offset_provider_type, + ) # type: ignore[assignment] # always an itir.Program inlined = InlineScalar.apply(inlined, offset_provider_type=offset_provider_type) # This pass is required to run after CollapseTuple as otherwise we can not inline @@ -122,7 +133,10 @@ def apply_common_transforms( # only run the unconditional version here instead of in the loop above. if unconditionally_collapse_tuples: ir = CollapseTuple.apply( - ir, ignore_tuple_size=True, offset_provider_type=offset_provider_type + ir, + ignore_tuple_size=True, + uids=collapse_tuple_uids, + offset_provider_type=offset_provider_type, ) # type: ignore[assignment] # always an itir.Program ir = NormalizeShifts().visit(ir) @@ -160,7 +174,9 @@ def apply_fieldview_transforms( ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) ir = CollapseTuple.apply( - ir, offset_provider_type=common.offset_provider_to_type(offset_provider) + ir, + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + offset_provider_type=common.offset_provider_to_type(offset_provider), ) # type: ignore[assignment] # type is still `itir.Program` ir = inline_dynamic_shifts.InlineDynamicShifts.apply( ir diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 1b980783fa..9f7a14b0b8 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -292,7 +292,9 @@ def type_synthesizer(*args, **kwargs): assert type_info.accepts_args(fun_type, with_args=list(args), with_kwargs=kwargs) return fun_type.returns - return type_synthesizer + return ObservableTypeSynthesizer( + type_synthesizer=type_synthesizer, store_inferred_type_in_node=False + ) class SanitizeTypes(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): @@ -312,6 +314,15 @@ def visit_Node(self, node: itir.Node, *, symtable: dict[str, itir.Node]) -> itir T = TypeVar("T", bound=itir.Node) +_INITIAL_CONTEXT = { + name: ObservableTypeSynthesizer( + type_synthesizer=type_synthesizer.builtin_type_synthesizers[name], + # builtin functions are polymorphic + store_inferred_type_in_node=False, + ) + for name in type_synthesizer.builtin_type_synthesizers.keys() +} + @dataclasses.dataclass class ITIRTypeInference(eve.NodeTranslator): @@ -323,11 +334,13 @@ class ITIRTypeInference(eve.NodeTranslator): PRESERVED_ANNEX_ATTRS = ("domain",) - offset_provider_type: common.OffsetProviderType + offset_provider_type: Optional[common.OffsetProviderType] #: Mapping from a dimension name to the actual dimension instance. - dimensions: dict[str, common.Dimension] + dimensions: Optional[dict[str, common.Dimension]] #: Allow sym refs to symbols that have not been declared. Mostly used in testing. allow_undeclared_symbols: bool + #: Reinference-mode skipping already typed nodes. + reinfer: bool @classmethod def apply( @@ -420,24 +433,44 @@ def apply( ) ), allow_undeclared_symbols=allow_undeclared_symbols, + reinfer=False, ) if not inplace: node = copy.deepcopy(node) - instance.visit( - node, - ctx={ - name: ObservableTypeSynthesizer( - type_synthesizer=type_synthesizer.builtin_type_synthesizers[name], - # builtin functions are polymorphic - store_inferred_type_in_node=False, - ) - for name in type_synthesizer.builtin_type_synthesizers.keys() - }, + instance.visit(node, ctx=_INITIAL_CONTEXT) + return node + + @classmethod + def apply_reinfer(cls, node: T) -> T: + """ + Given a partially typed node infer the type of ``node`` and its sub-nodes. + + Contrary to the regular inference, this method does not descend into already typed sub-nodes + and can be used as a lightweight way to restore type information during a pass. + + Note that this function is stateful, which is usually desired, and more performant. + + Arguments: + node: The :class:`itir.Node` to infer the types of. + """ + if node.type: # already inferred + return node + + instance = cls( + offset_provider_type=None, dimensions=None, allow_undeclared_symbols=True, reinfer=True ) + instance.visit(node, ctx=_INITIAL_CONTEXT) return node def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: + # we found a node that is typed, do not descend into children + if self.reinfer and isinstance(node, itir.Node) and node.type: + if isinstance(node.type, ts.FunctionType): + return _type_synthesizer_from_function_type(node.type) + return node.type + result = super().visit(node, **kwargs) + if isinstance(node, itir.Node): if isinstance(result, ts.TypeSpec): if node.type and not isinstance(node.type, ts.DeferredType): @@ -519,19 +552,22 @@ def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None: ) def visit_AxisLiteral(self, node: itir.AxisLiteral, **kwargs) -> ts.DimensionType: - assert ( - node.value in self.dimensions - ), f"Dimension {node.value} not present in offset provider." - return ts.DimensionType(dim=self.dimensions[node.value]) + return ts.DimensionType(dim=common.Dimension(value=node.value, kind=node.kind)) # TODO: revisit what we want to do with OffsetLiterals as we already have an Offset type in # the frontend. - def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs) -> it_ts.OffsetLiteralType: + def visit_OffsetLiteral( + self, node: itir.OffsetLiteral, **kwargs + ) -> it_ts.OffsetLiteralType | ts.DeferredType: + if self.reinfer: + return ts.DeferredType(constraint=it_ts.OffsetLiteralType) + if _is_representable_as_int(node.value): return it_ts.OffsetLiteralType( value=ts.ScalarType(kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())) ) else: + assert isinstance(self.dimensions, dict) assert isinstance(node.value, str) and node.value in self.dimensions return it_ts.OffsetLiteralType(value=self.dimensions[node.value]) @@ -608,3 +644,5 @@ def visit_Node(self, node: itir.Node, **kwargs): infer = ITIRTypeInference.apply + +reinfer = ITIRTypeInference.apply_reinfer diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 5be9ed7438..5fc78a7c6f 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -94,6 +94,10 @@ def power(base: ts.ScalarType, exponent: ts.ScalarType) -> ts.ScalarType: @_register_builtin_type_synthesizer(fun_names=itir.BINARY_MATH_NUMBER_BUILTINS) def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType: + if isinstance(lhs, ts.DeferredType): + return rhs + if isinstance(rhs, ts.DeferredType): + return lhs assert lhs == rhs return lhs diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 7eb4e86adb..e333015152 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -5,6 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import copy # TODO: test failure when something is not typed after inference is run # TODO: test lift with no args @@ -15,6 +16,7 @@ import pytest +from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.type_system import ( @@ -80,7 +82,9 @@ def expression_test_cases(): (im.call("make_const_list")(True), it_ts.ListType(element_type=bool_type)), (im.call("list_get")(0, im.ref("l", it_ts.ListType(element_type=bool_type))), bool_type), ( - im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1), + im.call("named_range")( + itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 + ), it_ts.NamedRangeType(dim=Vertex), ), ( @@ -91,7 +95,9 @@ def expression_test_cases(): ), ( im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1) + im.call("named_range")( + itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 + ) ), it_ts.DomainType(dims=[Vertex]), ), @@ -157,8 +163,14 @@ def expression_test_cases(): im.call("as_fieldop")( im.lambda_("it")(im.deref(im.shift("V2E", 0)("it"))), im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1), - im.call("named_range")(itir.AxisLiteral(value="KDim"), 0, 1), + im.call("named_range")( + itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), + 0, + 1, + ), + im.call("named_range")( + itir.AxisLiteral(value="KDim", kind=common.DimensionKind.VERTICAL), 0, 1 + ), ), ) )(im.ref("inp", float_edge_k_field)), @@ -309,8 +321,12 @@ def test_cartesian_fencil_definition(): def test_unstructured_fencil_definition(): mesh = simple_mesh() unstructured_domain = im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1), - im.call("named_range")(itir.AxisLiteral(value="KDim"), 0, 1), + im.call("named_range")( + itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 + ), + im.call("named_range")( + itir.AxisLiteral(value="KDim", kind=common.DimensionKind.VERTICAL), 0, 1 + ), ) testee = itir.Program( @@ -376,8 +392,12 @@ def test_function_definition(): def test_fencil_with_nb_field_input(): mesh = simple_mesh() unstructured_domain = im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1), - im.call("named_range")(itir.AxisLiteral(value="KDim"), 0, 1), + im.call("named_range")( + itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 + ), + im.call("named_range")( + itir.AxisLiteral(value="KDim", kind=common.DimensionKind.VERTICAL), 0, 1 + ), ) testee = itir.Program( @@ -501,3 +521,21 @@ def test_as_fieldop_without_domain(): assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType( position_dims="unknown", defined_dims=float_i_field.dims, element_type=float_i_field.dtype ) + + +def test_reinference(): + testee = im.make_tuple(im.ref("inp1", float_i_field), im.ref("inp2", float_i_field)) + result = itir_type_inference.reinfer(copy.deepcopy(testee)) + assert result.type == ts.TupleType(types=[float_i_field, float_i_field]) + + +def test_func_reinference(): + f_type = ts.FunctionType( + pos_only_args=[], + pos_or_kw_args={}, + kw_only_args={}, + returns=float_i_field, + ) + testee = im.call(im.ref("f", f_type))() + result = itir_type_inference.reinfer(copy.deepcopy(testee)) + assert result.type == float_i_field diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 28090ff1e2..5e2c07ef0a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -9,6 +9,7 @@ from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple from gt4py.next.type_system import type_specifications as ts +from tests.next_tests.unit_tests.iterator_tests.test_type_inference import int_type def test_simple_make_tuple_tuple_get(): @@ -127,8 +128,8 @@ def test_letify_make_tuple_elements(): # anything that is not trivial, i.e. a SymRef, works here el1, el2 = im.let("foo", "foo")("foo"), im.let("bar", "bar")("bar") testee = im.make_tuple(el1, el2) - expected = im.let(("_tuple_el_1", el1), ("_tuple_el_2", el2))( - im.make_tuple("_tuple_el_1", "_tuple_el_2") + expected = im.let(("__ct_el_1", el1), ("__ct_el_2", el2))( + im.make_tuple("__ct_el_1", "__ct_el_2") ) actual = CollapseTuple.apply( @@ -239,3 +240,74 @@ def test_tuple_get_on_untyped_ref(): actual = CollapseTuple.apply(testee, allow_undeclared_symbols=True, within_stencil=False) assert actual == testee + + +def test_if_make_tuple_reorder_cps(): + testee = im.let("t", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4)))( + im.make_tuple(im.tuple_get(1, "t"), im.tuple_get(0, "t")) + ) + expected = im.if_(True, im.make_tuple(2, 1), im.make_tuple(4, 3)) + actual = CollapseTuple.apply( + testee, + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected + + +def test_nested_if_make_tuple_reorder_cps(): + testee = im.let( + ("t1", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4))), + ("t2", im.if_(False, im.make_tuple(5, 6), im.make_tuple(7, 8))), + )( + im.make_tuple( + im.tuple_get(1, "t1"), + im.tuple_get(0, "t1"), + im.tuple_get(1, "t2"), + im.tuple_get(0, "t2"), + ) + ) + expected = im.if_( + True, + im.if_(False, im.make_tuple(2, 1, 6, 5), im.make_tuple(2, 1, 8, 7)), + im.if_(False, im.make_tuple(4, 3, 6, 5), im.make_tuple(4, 3, 8, 7)), + ) + actual = CollapseTuple.apply( + testee, + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected + + +def test_if_make_tuple_reorder_cps_nested(): + testee = im.let("t", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4)))( + im.let("c", im.tuple_get(0, "t"))( + im.make_tuple(im.tuple_get(1, "t"), im.tuple_get(0, "t"), "c") + ) + ) + expected = im.if_(True, im.make_tuple(2, 1, 1), im.make_tuple(4, 3, 3)) + actual = CollapseTuple.apply( + testee, + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected + + +def test_if_make_tuple_reorder_cps_external(): + external_ref = im.tuple_get(0, im.ref("external", ts.TupleType(types=[int_type]))) + testee = im.let("t", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4)))( + im.make_tuple(external_ref, im.tuple_get(1, "t"), im.tuple_get(0, "t")) + ) + expected = im.if_(True, im.make_tuple(external_ref, 2, 1), im.make_tuple(external_ref, 4, 3)) + actual = CollapseTuple.apply( + testee, + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected