Skip to content

Commit

Permalink
refactor[next]: Use is_call_to instead of equality comparison with it…
Browse files Browse the repository at this point in the history
…ir.Ref. (#1532)

In the new ITIR type inference #1531 IR nodes store their type in the node itself. While we initially exclude the attribute from equality comparison we should nonetheless avoid comparison of nodes that only differ in type. This PR removes many of this occurrences.
  • Loading branch information
tehrengruber authored Apr 18, 2024
1 parent 9f44142 commit 2c06bc6
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 51 deletions.
23 changes: 19 additions & 4 deletions src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later
from collections.abc import Iterable
from typing import TypeGuard

from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im


def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]:
Expand All @@ -32,6 +32,21 @@ def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]:
return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda)


def is_if_call(node: itir.Expr) -> TypeGuard[itir.FunCall]:
"""Match expression of the form `if_(cond, true_branch, false_branch)`."""
return isinstance(node, itir.FunCall) and node.fun == im.ref("if_")
def is_call_to(node: itir.Node, fun: str | Iterable[str]) -> TypeGuard[itir.FunCall]:
"""
Match call expression to a given function.
>>> from gt4py.next.iterator.ir_utils import ir_makers as im
>>> node = im.call("plus")(1, 2)
>>> is_call_to(node, "plus")
True
>>> is_call_to(node, "minus")
False
>>> is_call_to(node, ("plus", "minus"))
True
"""
if isinstance(fun, (list, tuple, set, Iterable)) and not isinstance(fun, str):
return any((is_call_to(node, f) for f in fun))
return (
isinstance(node, itir.FunCall) and isinstance(node.fun, itir.SymRef) and node.fun.id == fun
)
33 changes: 16 additions & 17 deletions src/gt4py/next/iterator/transforms/collapse_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@
from gt4py.eve import utils as eve_utils
from gt4py.next import type_inference
from gt4py.next.iterator import ir, type_inference as it_type_inference
from gt4py.next.iterator.ir_utils import ir_makers as im, misc as ir_misc
from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_if_call, is_let
from gt4py.next.iterator.ir_utils import (
common_pattern_matcher as cpm,
ir_makers as im,
misc as ir_misc,
)
from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas, inline_lambda
from gt4py.next.type_system import type_info

Expand Down Expand Up @@ -66,7 +69,7 @@ def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr):

def _is_trivial_make_tuple_call(node: ir.Expr):
"""Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof."""
if not (isinstance(node, ir.FunCall) and node.fun == im.ref("make_tuple")):
if not cpm.is_call_to(node, "make_tuple"):
return False
if not all(
isinstance(arg, (ir.SymRef, ir.Literal)) or _is_trivial_make_tuple_call(arg)
Expand Down Expand Up @@ -247,7 +250,7 @@ def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]:
# TODO(tehrengruber): extend to general symbols as long as the tail call in the let
# does not capture
# `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))`
if is_let(node.args[1]):
if cpm.is_let(node.args[1]):
idx, let_expr = node.args
return im.call(
im.lambda_(*let_expr.fun.params)( # type: ignore[attr-defined] # ensured by is_let
Expand All @@ -256,7 +259,7 @@ def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]:
)(
*let_expr.args # type: ignore[attr-defined] # ensured by is_let
)
elif isinstance(node.args[1], ir.FunCall) and node.args[1].fun == im.ref("if_"):
elif cpm.is_call_to(node.args[1], "if_"):
idx = node.args[0]
cond, true_branch, false_branch = node.args[1].args
return im.if_(
Expand All @@ -273,11 +276,7 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall) -> Optional[ir.
bound_vars: dict[str, ir.Expr] = {}
new_args: list[ir.Expr] = []
for arg in node.args:
if (
isinstance(node, ir.FunCall)
and node.fun == im.ref("make_tuple")
and not _is_trivial_make_tuple_call(node)
):
if cpm.is_call_to(node, "make_tuple") and not _is_trivial_make_tuple_call(node):
el_name = self._letify_make_tuple_uids.sequential_id()
new_args.append(im.ref(el_name))
bound_vars[el_name] = arg
Expand All @@ -289,7 +288,7 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall) -> Optional[ir.
return None

def transform_inline_trivial_make_tuple(self, node: ir.FunCall) -> Optional[ir.Node]:
if is_let(node):
if cpm.is_let(node):
# `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))`
# -> `foo(make_tuple(trivial_expr1, trivial_expr2))`
eligible_params = [_is_trivial_make_tuple_call(arg) for arg in node.args]
Expand All @@ -298,30 +297,30 @@ def transform_inline_trivial_make_tuple(self, node: ir.FunCall) -> Optional[ir.N
return None

def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.Node]:
if not node.fun == im.ref("if_"):
if not cpm.is_call_to(node, "if_"):
# TODO(tehrengruber): This significantly increases the size of the tree. Revisit.
# TODO(tehrengruber): Only inline if type of branch value is a tuple.
# Examples:
# `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]`
# `let (b, if cond then {1, 2} else {3, 4})) b[0]`
# -> `if cond then let(b, {1, 2})(b[0]) else let(b, {3, 4})(b[0])`
for i, arg in enumerate(node.args):
if is_if_call(arg):
if cpm.is_call_to(arg, "if_"):
cond, true_branch, false_branch = arg.args
new_true_branch = self.fp_transform(_with_altered_arg(node, i, true_branch))
new_false_branch = self.fp_transform(_with_altered_arg(node, i, false_branch))
return im.if_(cond, new_true_branch, new_false_branch)
return None

def transform_propagate_nested_let(self, node: ir.FunCall) -> Optional[ir.Node]:
if is_let(node):
if cpm.is_let(node):
# `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))`
outer_vars = {}
inner_vars = {}
original_inner_expr = node.fun.expr # type: ignore[attr-defined] # ensured by is_let
for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let
assert arg_sym not in inner_vars # TODO(tehrengruber): fix collisions
if is_let(arg):
if cpm.is_let(arg):
for sym, val in zip(arg.fun.params, arg.args): # type: ignore[attr-defined] # ensured by is_let
assert sym not in outer_vars # TODO(tehrengruber): fix collisions
outer_vars[sym] = val
Expand All @@ -337,9 +336,9 @@ def transform_propagate_nested_let(self, node: ir.FunCall) -> Optional[ir.Node]:
return None

def transform_inline_trivial_let(self, node: ir.FunCall) -> Optional[ir.Node]:
if is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let
if cpm.is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let
# `let(a, 1)(a)` -> `1`
for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let
if node.fun.expr == im.ref(arg_sym.id): # type: ignore[attr-defined] # ensured by is_let
if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let
return arg
return None
20 changes: 10 additions & 10 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
from gt4py.eve.utils import UIDGenerator
from gt4py.next import common
from gt4py.next.iterator import ir, type_inference
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im
from gt4py.next.iterator.pretty_printer import PrettyPrinter
from gt4py.next.iterator.transforms import trace_shifts
from gt4py.next.iterator.transforms.cse import extract_subexpression
Expand Down Expand Up @@ -139,7 +138,7 @@ class TemporaryExtractionPredicate:

def __call__(self, expr: ir.Expr, num_occurences: int) -> bool:
"""Determine if `expr` is an applied lift that should be extracted as a temporary."""
if not is_applied_lift(expr):
if not cpm.is_applied_lift(expr):
return False
# do not extract when the result is a list (i.e. a lift expression used in a `reduce` call)
# as we can not create temporaries for these stencils
Expand Down Expand Up @@ -185,7 +184,7 @@ def _closure_parameter_argument_mapping(closure: ir.StencilClosure):
to `arg`. In case the stencil is a scan, a mapping from closure inputs to scan pass (i.e. first
arg is ignored) is returned.
"""
is_scan = isinstance(closure.stencil, ir.FunCall) and closure.stencil.fun == im.ref("scan")
is_scan = cpm.is_call_to(closure.stencil, "scan")

if is_scan:
stencil = closure.stencil.args[0] # type: ignore[attr-defined] # ensured by is_scan
Expand Down Expand Up @@ -242,13 +241,14 @@ def always_extract_heuristics(_):
while closure_stack:
current_closure: ir.StencilClosure = closure_stack.pop()

if current_closure.stencil == im.ref("deref"):
if (
isinstance(current_closure.stencil, ir.SymRef)
and current_closure.stencil.id == "deref"
):
closures.append(current_closure)
continue

is_scan: bool = isinstance(
current_closure.stencil, ir.FunCall
) and current_closure.stencil.fun == im.ref("scan")
is_scan: bool = cpm.is_call_to(current_closure.stencil, "scan")
current_closure_stencil = (
current_closure.stencil if not is_scan else current_closure.stencil.args[0] # type: ignore[attr-defined] # ensured by is_scan
)
Expand Down Expand Up @@ -571,7 +571,7 @@ def update_domains(


def _tuple_constituents(node: ir.Expr) -> Iterable[ir.Expr]:
if isinstance(node, ir.FunCall) and node.fun == im.ref("make_tuple"):
if cpm.is_call_to(node, "make_tuple"):
for arg in node.args:
yield from _tuple_constituents(arg)
else:
Expand Down Expand Up @@ -625,7 +625,7 @@ def validate_no_dynamic_offsets(node: ir.Node):
"""Vaidate we have no dynamic offsets, e.g. `shift(Ioff, deref(...))(...)`"""
for call_node in node.walk_values().if_isinstance(ir.FunCall):
assert isinstance(call_node, ir.FunCall)
if call_node.fun == im.ref("shift"):
if cpm.is_call_to(call_node, "shift"):
if any(not isinstance(arg, ir.OffsetLiteral) for arg in call_node.args):
raise NotImplementedError("Dynamic offsets not supported in temporary pass.")

Expand Down
25 changes: 6 additions & 19 deletions src/gt4py/next/iterator/transforms/propagate_deref.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.eve.pattern_matching import ObjectPattern as P
from gt4py.next.iterator import ir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im


# TODO(tehrengruber): This pass can be generalized to all builtins, e.g.
Expand Down Expand Up @@ -44,23 +43,11 @@ def apply(cls, node: ir.Node):
return cls().visit(node)

def visit_FunCall(self, node: ir.FunCall):
if P(ir.FunCall, fun=ir.SymRef(id="deref"), args=[P(ir.FunCall, fun=P(ir.Lambda))]).match(
node
):
builtin = node.fun
lambda_fun: ir.Lambda = node.args[0].fun # type: ignore[attr-defined] # invariant ensured by pattern match above
lambda_args: list[ir.Expr] = node.args[0].args # type: ignore[attr-defined] # invariant ensured by pattern match above
node = ir.FunCall(
fun=ir.Lambda(
params=lambda_fun.params, expr=ir.FunCall(fun=builtin, args=[lambda_fun.expr])
),
args=lambda_args,
)
elif (
node.fun == im.ref("deref")
and isinstance(node.args[0], ir.FunCall)
and node.args[0].fun == im.ref("if_")
):
if cpm.is_call_to(node, "deref") and cpm.is_let(node.args[0]):
fun: ir.Lambda = node.args[0].fun # type: ignore[assignment] # ensured by is_let
args: list[ir.Expr] = node.args[0].args
node = im.let(*zip(fun.params, args))(im.deref(fun.expr)) # type: ignore[arg-type] # mypy not smart enough
elif cpm.is_call_to(node, "deref") and cpm.is_call_to(node.args[0], "if_"):
cond, true_branch, false_branch = node.args[0].args
return im.if_(cond, im.deref(true_branch), im.deref(false_branch))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,17 @@
from gt4py.next.iterator.transforms.propagate_deref import PropagateDeref


def test_deref_propagation():
def test_deref_let_propagation():
testee = im.deref(im.call(im.lambda_("inner_it")(im.lift("stencil")("inner_it")))("outer_it"))
expected = im.call(im.lambda_("inner_it")(im.deref(im.lift("stencil")("inner_it"))))("outer_it")

actual = PropagateDeref.apply(testee)
assert actual == expected


def test_deref_if_propagation():
testee = im.deref(im.if_("cond", "true_branch", "false_branch"))
expected = im.if_("cond", im.deref("true_branch"), im.deref("false_branch"))

actual = PropagateDeref.apply(testee)
assert actual == expected

0 comments on commit 2c06bc6

Please sign in to comment.