From 2c06bc668aaacb21075fd39b179985d9e7172ae3 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 18 Apr 2024 14:48:37 +0200 Subject: [PATCH] refactor[next]: Use is_call_to instead of equality comparison with itir.Ref. (#1532) In the new ITIR type inference #1531 IR nodes store their type in the node itself. While we initially exclude the attribute from equality comparison we should nonetheless avoid comparison of nodes that only differ in type. This PR removes many of this occurrences. --- .../ir_utils/common_pattern_matcher.py | 23 ++++++++++--- .../iterator/transforms/collapse_tuple.py | 33 +++++++++---------- .../next/iterator/transforms/global_tmps.py | 20 +++++------ .../iterator/transforms/propagate_deref.py | 25 ++++---------- .../transforms_tests/test_propagate_deref.py | 10 +++++- 5 files changed, 60 insertions(+), 51 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index a4b074a4b6..4933307c53 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -11,10 +11,10 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later +from collections.abc import Iterable from typing import TypeGuard from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]: @@ -32,6 +32,21 @@ def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]: return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda) -def is_if_call(node: itir.Expr) -> TypeGuard[itir.FunCall]: - """Match expression of the form `if_(cond, true_branch, false_branch)`.""" - return isinstance(node, itir.FunCall) and node.fun == im.ref("if_") +def is_call_to(node: itir.Node, fun: str | Iterable[str]) -> TypeGuard[itir.FunCall]: + """ + Match call expression to a given function. + + >>> from gt4py.next.iterator.ir_utils import ir_makers as im + >>> node = im.call("plus")(1, 2) + >>> is_call_to(node, "plus") + True + >>> is_call_to(node, "minus") + False + >>> is_call_to(node, ("plus", "minus")) + True + """ + if isinstance(fun, (list, tuple, set, Iterable)) and not isinstance(fun, str): + return any((is_call_to(node, f) for f in fun)) + return ( + isinstance(node, itir.FunCall) and isinstance(node.fun, itir.SymRef) and node.fun.id == fun + ) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 4b8182a781..4e4443696f 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -23,8 +23,11 @@ from gt4py.eve import utils as eve_utils from gt4py.next import type_inference from gt4py.next.iterator import ir, type_inference as it_type_inference -from gt4py.next.iterator.ir_utils import ir_makers as im, misc as ir_misc -from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_if_call, is_let +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + ir_makers as im, + misc as ir_misc, +) from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas, inline_lambda from gt4py.next.type_system import type_info @@ -66,7 +69,7 @@ def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr): def _is_trivial_make_tuple_call(node: ir.Expr): """Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof.""" - if not (isinstance(node, ir.FunCall) and node.fun == im.ref("make_tuple")): + if not cpm.is_call_to(node, "make_tuple"): return False if not all( isinstance(arg, (ir.SymRef, ir.Literal)) or _is_trivial_make_tuple_call(arg) @@ -247,7 +250,7 @@ def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: # TODO(tehrengruber): extend to general symbols as long as the tail call in the let # does not capture # `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))` - if is_let(node.args[1]): + if cpm.is_let(node.args[1]): idx, let_expr = node.args return im.call( im.lambda_(*let_expr.fun.params)( # type: ignore[attr-defined] # ensured by is_let @@ -256,7 +259,7 @@ def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: )( *let_expr.args # type: ignore[attr-defined] # ensured by is_let ) - elif isinstance(node.args[1], ir.FunCall) and node.args[1].fun == im.ref("if_"): + elif cpm.is_call_to(node.args[1], "if_"): idx = node.args[0] cond, true_branch, false_branch = node.args[1].args return im.if_( @@ -273,11 +276,7 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall) -> Optional[ir. bound_vars: dict[str, ir.Expr] = {} new_args: list[ir.Expr] = [] for arg in node.args: - if ( - isinstance(node, ir.FunCall) - and node.fun == im.ref("make_tuple") - and not _is_trivial_make_tuple_call(node) - ): + if cpm.is_call_to(node, "make_tuple") and not _is_trivial_make_tuple_call(node): el_name = self._letify_make_tuple_uids.sequential_id() new_args.append(im.ref(el_name)) bound_vars[el_name] = arg @@ -289,7 +288,7 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall) -> Optional[ir. return None def transform_inline_trivial_make_tuple(self, node: ir.FunCall) -> Optional[ir.Node]: - if is_let(node): + if cpm.is_let(node): # `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))` # -> `foo(make_tuple(trivial_expr1, trivial_expr2))` eligible_params = [_is_trivial_make_tuple_call(arg) for arg in node.args] @@ -298,7 +297,7 @@ def transform_inline_trivial_make_tuple(self, node: ir.FunCall) -> Optional[ir.N return None def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.Node]: - if not node.fun == im.ref("if_"): + if not cpm.is_call_to(node, "if_"): # TODO(tehrengruber): This significantly increases the size of the tree. Revisit. # TODO(tehrengruber): Only inline if type of branch value is a tuple. # Examples: @@ -306,7 +305,7 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.N # `let (b, if cond then {1, 2} else {3, 4})) b[0]` # -> `if cond then let(b, {1, 2})(b[0]) else let(b, {3, 4})(b[0])` for i, arg in enumerate(node.args): - if is_if_call(arg): + if cpm.is_call_to(arg, "if_"): cond, true_branch, false_branch = arg.args new_true_branch = self.fp_transform(_with_altered_arg(node, i, true_branch)) new_false_branch = self.fp_transform(_with_altered_arg(node, i, false_branch)) @@ -314,14 +313,14 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.N return None def transform_propagate_nested_let(self, node: ir.FunCall) -> Optional[ir.Node]: - if is_let(node): + if cpm.is_let(node): # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` outer_vars = {} inner_vars = {} original_inner_expr = node.fun.expr # type: ignore[attr-defined] # ensured by is_let for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let assert arg_sym not in inner_vars # TODO(tehrengruber): fix collisions - if is_let(arg): + if cpm.is_let(arg): for sym, val in zip(arg.fun.params, arg.args): # type: ignore[attr-defined] # ensured by is_let assert sym not in outer_vars # TODO(tehrengruber): fix collisions outer_vars[sym] = val @@ -337,9 +336,9 @@ def transform_propagate_nested_let(self, node: ir.FunCall) -> Optional[ir.Node]: return None def transform_inline_trivial_let(self, node: ir.FunCall) -> Optional[ir.Node]: - if is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let + if cpm.is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let # `let(a, 1)(a)` -> `1` for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let - if node.fun.expr == im.ref(arg_sym.id): # type: ignore[attr-defined] # ensured by is_let + if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let return arg return None diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 42d68318a0..a3260d5a37 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -23,8 +23,7 @@ from gt4py.eve.utils import UIDGenerator from gt4py.next import common from gt4py.next.iterator import ir, type_inference -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.pretty_printer import PrettyPrinter from gt4py.next.iterator.transforms import trace_shifts from gt4py.next.iterator.transforms.cse import extract_subexpression @@ -139,7 +138,7 @@ class TemporaryExtractionPredicate: def __call__(self, expr: ir.Expr, num_occurences: int) -> bool: """Determine if `expr` is an applied lift that should be extracted as a temporary.""" - if not is_applied_lift(expr): + if not cpm.is_applied_lift(expr): return False # do not extract when the result is a list (i.e. a lift expression used in a `reduce` call) # as we can not create temporaries for these stencils @@ -185,7 +184,7 @@ def _closure_parameter_argument_mapping(closure: ir.StencilClosure): to `arg`. In case the stencil is a scan, a mapping from closure inputs to scan pass (i.e. first arg is ignored) is returned. """ - is_scan = isinstance(closure.stencil, ir.FunCall) and closure.stencil.fun == im.ref("scan") + is_scan = cpm.is_call_to(closure.stencil, "scan") if is_scan: stencil = closure.stencil.args[0] # type: ignore[attr-defined] # ensured by is_scan @@ -242,13 +241,14 @@ def always_extract_heuristics(_): while closure_stack: current_closure: ir.StencilClosure = closure_stack.pop() - if current_closure.stencil == im.ref("deref"): + if ( + isinstance(current_closure.stencil, ir.SymRef) + and current_closure.stencil.id == "deref" + ): closures.append(current_closure) continue - is_scan: bool = isinstance( - current_closure.stencil, ir.FunCall - ) and current_closure.stencil.fun == im.ref("scan") + is_scan: bool = cpm.is_call_to(current_closure.stencil, "scan") current_closure_stencil = ( current_closure.stencil if not is_scan else current_closure.stencil.args[0] # type: ignore[attr-defined] # ensured by is_scan ) @@ -571,7 +571,7 @@ def update_domains( def _tuple_constituents(node: ir.Expr) -> Iterable[ir.Expr]: - if isinstance(node, ir.FunCall) and node.fun == im.ref("make_tuple"): + if cpm.is_call_to(node, "make_tuple"): for arg in node.args: yield from _tuple_constituents(arg) else: @@ -625,7 +625,7 @@ def validate_no_dynamic_offsets(node: ir.Node): """Vaidate we have no dynamic offsets, e.g. `shift(Ioff, deref(...))(...)`""" for call_node in node.walk_values().if_isinstance(ir.FunCall): assert isinstance(call_node, ir.FunCall) - if call_node.fun == im.ref("shift"): + if cpm.is_call_to(call_node, "shift"): if any(not isinstance(arg, ir.OffsetLiteral) for arg in call_node.args): raise NotImplementedError("Dynamic offsets not supported in temporary pass.") diff --git a/src/gt4py/next/iterator/transforms/propagate_deref.py b/src/gt4py/next/iterator/transforms/propagate_deref.py index 21551fab6a..9f3e9d48fc 100644 --- a/src/gt4py/next/iterator/transforms/propagate_deref.py +++ b/src/gt4py/next/iterator/transforms/propagate_deref.py @@ -13,9 +13,8 @@ # SPDX-License-Identifier: GPL-3.0-or-later from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.eve.pattern_matching import ObjectPattern as P from gt4py.next.iterator import ir -from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im # TODO(tehrengruber): This pass can be generalized to all builtins, e.g. @@ -44,23 +43,11 @@ def apply(cls, node: ir.Node): return cls().visit(node) def visit_FunCall(self, node: ir.FunCall): - if P(ir.FunCall, fun=ir.SymRef(id="deref"), args=[P(ir.FunCall, fun=P(ir.Lambda))]).match( - node - ): - builtin = node.fun - lambda_fun: ir.Lambda = node.args[0].fun # type: ignore[attr-defined] # invariant ensured by pattern match above - lambda_args: list[ir.Expr] = node.args[0].args # type: ignore[attr-defined] # invariant ensured by pattern match above - node = ir.FunCall( - fun=ir.Lambda( - params=lambda_fun.params, expr=ir.FunCall(fun=builtin, args=[lambda_fun.expr]) - ), - args=lambda_args, - ) - elif ( - node.fun == im.ref("deref") - and isinstance(node.args[0], ir.FunCall) - and node.args[0].fun == im.ref("if_") - ): + if cpm.is_call_to(node, "deref") and cpm.is_let(node.args[0]): + fun: ir.Lambda = node.args[0].fun # type: ignore[assignment] # ensured by is_let + args: list[ir.Expr] = node.args[0].args + node = im.let(*zip(fun.params, args))(im.deref(fun.expr)) # type: ignore[arg-type] # mypy not smart enough + elif cpm.is_call_to(node, "deref") and cpm.is_call_to(node.args[0], "if_"): cond, true_branch, false_branch = node.args[0].args return im.if_(cond, im.deref(true_branch), im.deref(false_branch)) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py index e2e29cd4db..899c108a98 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py @@ -16,9 +16,17 @@ from gt4py.next.iterator.transforms.propagate_deref import PropagateDeref -def test_deref_propagation(): +def test_deref_let_propagation(): testee = im.deref(im.call(im.lambda_("inner_it")(im.lift("stencil")("inner_it")))("outer_it")) expected = im.call(im.lambda_("inner_it")(im.deref(im.lift("stencil")("inner_it"))))("outer_it") actual = PropagateDeref.apply(testee) assert actual == expected + + +def test_deref_if_propagation(): + testee = im.deref(im.if_("cond", "true_branch", "false_branch")) + expected = im.if_("cond", im.deref("true_branch"), im.deref("false_branch")) + + actual = PropagateDeref.apply(testee) + assert actual == expected