Skip to content

Commit

Permalink
Fix broken iterator tests containing lifts
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber committed Dec 2, 2024
1 parent 22dbf5e commit bfacda6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/gt4py/next/iterator/transforms/fuse_as_fieldop.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,11 @@ def fuse_as_fieldop(
def _arg_inline_predicate(node: itir.Expr, shifts):
if _is_tuple_expr_of_literals(node):
return True
if (is_applied_fieldop := cpm.is_applied_as_fieldop(node)) or cpm.is_call_to(node, "if_"):
# TODO(tehrengruber): write test case ensuring scan is not tried to be inlined (e.g. test_call_scan_operator_from_field_operator)
if (
is_applied_fieldop := cpm.is_applied_as_fieldop(node)
and not cpm.is_call_to(node.fun.args[0], "scan") # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop
) or cpm.is_call_to(node, "if_"):
# always inline arg if it is an applied fieldop with only a single arg
if is_applied_fieldop and len(node.args) == 1:
return True
Expand Down Expand Up @@ -264,7 +268,10 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs):
# TODO(tehrengruber): Write test-case. E.g. Adding two sparse fields. Sara observed this
# with a cast to a sparse field, but this is likely already covered.
if cpm.is_let(node):
eligible_args = [isinstance(arg.type, ts.FieldType) and isinstance(arg.type.dtype, it_ts.ListType) for arg in node.args]
eligible_args = [
isinstance(arg.type, ts.FieldType) and isinstance(arg.type.dtype, it_ts.ListType)
for arg in node.args
]
if any(eligible_args):
node = inline_lambdas.inline_lambda(node, eligible_params=eligible_args)
return self.visit(node)
Expand Down
4 changes: 4 additions & 0 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ def apply_common_transforms(
ir = inline_fundefs.prune_unreferenced_fundefs(ir) # type: ignore[arg-type] # all previous passes return itir.Program
ir = NormalizeShifts().visit(ir)

# TODO(tehrengruber): Many iterator test contain lifts that need to be inlined, e.g.
# test_can_deref. We didn't notice previously as FieldOpFusion did this implicitly everywhere.
ir = inline_lifts.InlineLifts().visit(ir)

# note: this increases the size of the tree
# Inline. The domain inference can not handle "user" functions, e.g. `let f = λ(...) → ... in f(...)`
ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True)
Expand Down

0 comments on commit bfacda6

Please sign in to comment.