diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index a4b074a4b6..4933307c53 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -11,10 +11,10 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later +from collections.abc import Iterable from typing import TypeGuard from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]: @@ -32,6 +32,21 @@ def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]: return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda) -def is_if_call(node: itir.Expr) -> TypeGuard[itir.FunCall]: - """Match expression of the form `if_(cond, true_branch, false_branch)`.""" - return isinstance(node, itir.FunCall) and node.fun == im.ref("if_") +def is_call_to(node: itir.Node, fun: str | Iterable[str]) -> TypeGuard[itir.FunCall]: + """ + Match call expression to a given function. + + >>> from gt4py.next.iterator.ir_utils import ir_makers as im + >>> node = im.call("plus")(1, 2) + >>> is_call_to(node, "plus") + True + >>> is_call_to(node, "minus") + False + >>> is_call_to(node, ("plus", "minus")) + True + """ + if isinstance(fun, (list, tuple, set, Iterable)) and not isinstance(fun, str): + return any((is_call_to(node, f) for f in fun)) + return ( + isinstance(node, itir.FunCall) and isinstance(node.fun, itir.SymRef) and node.fun.id == fun + ) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 4b8182a781..4e4443696f 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -23,8 +23,11 @@ from gt4py.eve import utils as eve_utils from gt4py.next import type_inference from gt4py.next.iterator import ir, type_inference as it_type_inference -from gt4py.next.iterator.ir_utils import ir_makers as im, misc as ir_misc -from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_if_call, is_let +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + ir_makers as im, + misc as ir_misc, +) from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas, inline_lambda from gt4py.next.type_system import type_info @@ -66,7 +69,7 @@ def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr): def _is_trivial_make_tuple_call(node: ir.Expr): """Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof.""" - if not (isinstance(node, ir.FunCall) and node.fun == im.ref("make_tuple")): + if not cpm.is_call_to(node, "make_tuple"): return False if not all( isinstance(arg, (ir.SymRef, ir.Literal)) or _is_trivial_make_tuple_call(arg) @@ -247,7 +250,7 @@ def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: # TODO(tehrengruber): extend to general symbols as long as the tail call in the let # does not capture # `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))` - if is_let(node.args[1]): + if cpm.is_let(node.args[1]): idx, let_expr = node.args return im.call( im.lambda_(*let_expr.fun.params)( # type: ignore[attr-defined] # ensured by is_let @@ -256,7 +259,7 @@ def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: )( *let_expr.args # type: ignore[attr-defined] # ensured by is_let ) - elif isinstance(node.args[1], ir.FunCall) and node.args[1].fun == im.ref("if_"): + elif cpm.is_call_to(node.args[1], "if_"): idx = node.args[0] cond, true_branch, false_branch = node.args[1].args return im.if_( @@ -273,11 +276,7 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall) -> Optional[ir. bound_vars: dict[str, ir.Expr] = {} new_args: list[ir.Expr] = [] for arg in node.args: - if ( - isinstance(node, ir.FunCall) - and node.fun == im.ref("make_tuple") - and not _is_trivial_make_tuple_call(node) - ): + if cpm.is_call_to(node, "make_tuple") and not _is_trivial_make_tuple_call(node): el_name = self._letify_make_tuple_uids.sequential_id() new_args.append(im.ref(el_name)) bound_vars[el_name] = arg @@ -289,7 +288,7 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall) -> Optional[ir. return None def transform_inline_trivial_make_tuple(self, node: ir.FunCall) -> Optional[ir.Node]: - if is_let(node): + if cpm.is_let(node): # `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))` # -> `foo(make_tuple(trivial_expr1, trivial_expr2))` eligible_params = [_is_trivial_make_tuple_call(arg) for arg in node.args] @@ -298,7 +297,7 @@ def transform_inline_trivial_make_tuple(self, node: ir.FunCall) -> Optional[ir.N return None def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.Node]: - if not node.fun == im.ref("if_"): + if not cpm.is_call_to(node, "if_"): # TODO(tehrengruber): This significantly increases the size of the tree. Revisit. # TODO(tehrengruber): Only inline if type of branch value is a tuple. # Examples: @@ -306,7 +305,7 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.N # `let (b, if cond then {1, 2} else {3, 4})) b[0]` # -> `if cond then let(b, {1, 2})(b[0]) else let(b, {3, 4})(b[0])` for i, arg in enumerate(node.args): - if is_if_call(arg): + if cpm.is_call_to(arg, "if_"): cond, true_branch, false_branch = arg.args new_true_branch = self.fp_transform(_with_altered_arg(node, i, true_branch)) new_false_branch = self.fp_transform(_with_altered_arg(node, i, false_branch)) @@ -314,14 +313,14 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.N return None def transform_propagate_nested_let(self, node: ir.FunCall) -> Optional[ir.Node]: - if is_let(node): + if cpm.is_let(node): # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` outer_vars = {} inner_vars = {} original_inner_expr = node.fun.expr # type: ignore[attr-defined] # ensured by is_let for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let assert arg_sym not in inner_vars # TODO(tehrengruber): fix collisions - if is_let(arg): + if cpm.is_let(arg): for sym, val in zip(arg.fun.params, arg.args): # type: ignore[attr-defined] # ensured by is_let assert sym not in outer_vars # TODO(tehrengruber): fix collisions outer_vars[sym] = val @@ -337,9 +336,9 @@ def transform_propagate_nested_let(self, node: ir.FunCall) -> Optional[ir.Node]: return None def transform_inline_trivial_let(self, node: ir.FunCall) -> Optional[ir.Node]: - if is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let + if cpm.is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let # `let(a, 1)(a)` -> `1` for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let - if node.fun.expr == im.ref(arg_sym.id): # type: ignore[attr-defined] # ensured by is_let + if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let return arg return None diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 42d68318a0..a3260d5a37 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -23,8 +23,7 @@ from gt4py.eve.utils import UIDGenerator from gt4py.next import common from gt4py.next.iterator import ir, type_inference -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.pretty_printer import PrettyPrinter from gt4py.next.iterator.transforms import trace_shifts from gt4py.next.iterator.transforms.cse import extract_subexpression @@ -139,7 +138,7 @@ class TemporaryExtractionPredicate: def __call__(self, expr: ir.Expr, num_occurences: int) -> bool: """Determine if `expr` is an applied lift that should be extracted as a temporary.""" - if not is_applied_lift(expr): + if not cpm.is_applied_lift(expr): return False # do not extract when the result is a list (i.e. a lift expression used in a `reduce` call) # as we can not create temporaries for these stencils @@ -185,7 +184,7 @@ def _closure_parameter_argument_mapping(closure: ir.StencilClosure): to `arg`. In case the stencil is a scan, a mapping from closure inputs to scan pass (i.e. first arg is ignored) is returned. """ - is_scan = isinstance(closure.stencil, ir.FunCall) and closure.stencil.fun == im.ref("scan") + is_scan = cpm.is_call_to(closure.stencil, "scan") if is_scan: stencil = closure.stencil.args[0] # type: ignore[attr-defined] # ensured by is_scan @@ -242,13 +241,14 @@ def always_extract_heuristics(_): while closure_stack: current_closure: ir.StencilClosure = closure_stack.pop() - if current_closure.stencil == im.ref("deref"): + if ( + isinstance(current_closure.stencil, ir.SymRef) + and current_closure.stencil.id == "deref" + ): closures.append(current_closure) continue - is_scan: bool = isinstance( - current_closure.stencil, ir.FunCall - ) and current_closure.stencil.fun == im.ref("scan") + is_scan: bool = cpm.is_call_to(current_closure.stencil, "scan") current_closure_stencil = ( current_closure.stencil if not is_scan else current_closure.stencil.args[0] # type: ignore[attr-defined] # ensured by is_scan ) @@ -571,7 +571,7 @@ def update_domains( def _tuple_constituents(node: ir.Expr) -> Iterable[ir.Expr]: - if isinstance(node, ir.FunCall) and node.fun == im.ref("make_tuple"): + if cpm.is_call_to(node, "make_tuple"): for arg in node.args: yield from _tuple_constituents(arg) else: @@ -625,7 +625,7 @@ def validate_no_dynamic_offsets(node: ir.Node): """Vaidate we have no dynamic offsets, e.g. `shift(Ioff, deref(...))(...)`""" for call_node in node.walk_values().if_isinstance(ir.FunCall): assert isinstance(call_node, ir.FunCall) - if call_node.fun == im.ref("shift"): + if cpm.is_call_to(call_node, "shift"): if any(not isinstance(arg, ir.OffsetLiteral) for arg in call_node.args): raise NotImplementedError("Dynamic offsets not supported in temporary pass.") diff --git a/src/gt4py/next/iterator/transforms/propagate_deref.py b/src/gt4py/next/iterator/transforms/propagate_deref.py index 21551fab6a..9f3e9d48fc 100644 --- a/src/gt4py/next/iterator/transforms/propagate_deref.py +++ b/src/gt4py/next/iterator/transforms/propagate_deref.py @@ -13,9 +13,8 @@ # SPDX-License-Identifier: GPL-3.0-or-later from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.eve.pattern_matching import ObjectPattern as P from gt4py.next.iterator import ir -from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im # TODO(tehrengruber): This pass can be generalized to all builtins, e.g. @@ -44,23 +43,11 @@ def apply(cls, node: ir.Node): return cls().visit(node) def visit_FunCall(self, node: ir.FunCall): - if P(ir.FunCall, fun=ir.SymRef(id="deref"), args=[P(ir.FunCall, fun=P(ir.Lambda))]).match( - node - ): - builtin = node.fun - lambda_fun: ir.Lambda = node.args[0].fun # type: ignore[attr-defined] # invariant ensured by pattern match above - lambda_args: list[ir.Expr] = node.args[0].args # type: ignore[attr-defined] # invariant ensured by pattern match above - node = ir.FunCall( - fun=ir.Lambda( - params=lambda_fun.params, expr=ir.FunCall(fun=builtin, args=[lambda_fun.expr]) - ), - args=lambda_args, - ) - elif ( - node.fun == im.ref("deref") - and isinstance(node.args[0], ir.FunCall) - and node.args[0].fun == im.ref("if_") - ): + if cpm.is_call_to(node, "deref") and cpm.is_let(node.args[0]): + fun: ir.Lambda = node.args[0].fun # type: ignore[assignment] # ensured by is_let + args: list[ir.Expr] = node.args[0].args + node = im.let(*zip(fun.params, args))(im.deref(fun.expr)) # type: ignore[arg-type] # mypy not smart enough + elif cpm.is_call_to(node, "deref") and cpm.is_call_to(node.args[0], "if_"): cond, true_branch, false_branch = node.args[0].args return im.if_(cond, im.deref(true_branch), im.deref(false_branch)) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py index e2e29cd4db..899c108a98 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py @@ -16,9 +16,17 @@ from gt4py.next.iterator.transforms.propagate_deref import PropagateDeref -def test_deref_propagation(): +def test_deref_let_propagation(): testee = im.deref(im.call(im.lambda_("inner_it")(im.lift("stencil")("inner_it")))("outer_it")) expected = im.call(im.lambda_("inner_it")(im.deref(im.lift("stencil")("inner_it"))))("outer_it") actual = PropagateDeref.apply(testee) assert actual == expected + + +def test_deref_if_propagation(): + testee = im.deref(im.if_("cond", "true_branch", "false_branch")) + expected = im.if_("cond", im.deref("true_branch"), im.deref("false_branch")) + + actual = PropagateDeref.apply(testee) + assert actual == expected