Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Dec 6, 2024
2 parents c22cfc8 + ae62965 commit 59e0ed5
Show file tree
Hide file tree
Showing 8 changed files with 492 additions and 237 deletions.
209 changes: 116 additions & 93 deletions src/gt4py/next/iterator/transforms/fuse_as_fieldop.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall:
if cpm.is_ref_to(stencil, "deref"):
stencil = im.lambda_("arg")(im.deref("arg"))
new_expr = im.as_fieldop(stencil, domain)(*expr.args)
type_inference.copy_type(from_=expr, to=new_expr)
type_inference.copy_type(from_=expr, to=new_expr, allow_untyped=True)

return new_expr

Expand All @@ -68,6 +68,107 @@ def _is_tuple_expr_of_literals(expr: itir.Expr):
return isinstance(expr, itir.Literal)


def _inline_as_fieldop_arg(
arg: itir.Expr, *, uids: eve_utils.UIDGenerator
) -> tuple[itir.Expr, dict[str, itir.Expr]]:
assert cpm.is_applied_as_fieldop(arg)
arg = _canonicalize_as_fieldop(arg)

stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop`
inner_args: list[itir.Expr] = arg.args
extracted_args: dict[str, itir.Expr] = {} # mapping from outer-stencil param to arg

stencil_params: list[itir.Sym] = []
stencil_body: itir.Expr = stencil.expr

for inner_param, inner_arg in zip(stencil.params, inner_args, strict=True):
if isinstance(inner_arg, itir.SymRef):
stencil_params.append(inner_param)
extracted_args[inner_arg.id] = inner_arg
elif isinstance(inner_arg, itir.Literal):
# note: only literals, not all scalar expressions are required as it doesn't make sense
# for them to be computed per grid point.
stencil_body = im.let(inner_param, im.promote_to_const_iterator(inner_arg))(
stencil_body
)
else:
# a scalar expression, a previously not inlined `as_fieldop` call or an opaque
# expression e.g. containing a tuple
stencil_params.append(inner_param)
new_outer_stencil_param = uids.sequential_id(prefix="__iasfop")
extracted_args[new_outer_stencil_param] = inner_arg

return im.lift(im.lambda_(*stencil_params)(stencil_body))(
*extracted_args.keys()
), extracted_args


def fuse_as_fieldop(
expr: itir.Expr, eligible_args: list[bool], *, uids: eve_utils.UIDGenerator
) -> itir.Expr:
assert cpm.is_applied_as_fieldop(expr) and isinstance(expr.fun.args[0], itir.Lambda) # type: ignore[attr-defined] # ensured by is_applied_as_fieldop

stencil: itir.Lambda = expr.fun.args[0] # type: ignore[attr-defined] # ensured by is_applied_as_fieldop
domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined] # ensured by is_applied_as_fieldop

args: list[itir.Expr] = expr.args

new_args: dict[str, itir.Expr] = {}
new_stencil_body: itir.Expr = stencil.expr

for eligible, stencil_param, arg in zip(eligible_args, stencil.params, args, strict=True):
if eligible:
if cpm.is_applied_as_fieldop(arg):
pass
elif cpm.is_call_to(arg, "if_"):
# TODO(tehrengruber): revisit if we want to inline if_
type_ = arg.type
arg = im.op_as_fieldop("if_")(*arg.args)
arg.type = type_
elif _is_tuple_expr_of_literals(arg):
arg = im.op_as_fieldop(im.lambda_()(arg))()
else:
raise NotImplementedError()

inline_expr, extracted_args = _inline_as_fieldop_arg(arg, uids=uids)

new_stencil_body = im.let(stencil_param, inline_expr)(new_stencil_body)

new_args = _merge_arguments(new_args, extracted_args)
else:
# just a safety check if typing information is available
if arg.type and not isinstance(arg.type, ts.DeferredType):
assert isinstance(arg.type, ts.TypeSpec)
dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type)
assert not isinstance(dtype, it_ts.ListType)
new_param: str
if isinstance(
arg, itir.SymRef
): # use name from outer scope (optional, just to get a nice IR)
new_param = arg.id
new_stencil_body = im.let(stencil_param.id, arg.id)(new_stencil_body)
else:
new_param = stencil_param.id
new_args = _merge_arguments(new_args, {new_param: arg})

new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)(
*new_args.values()
)

# simplify stencil directly to keep the tree small
new_node = inline_center_deref_lift_vars.InlineCenterDerefLiftVars.apply(
new_node
) # to keep the tree small
new_node = inline_lambdas.InlineLambdas.apply(
new_node, opcount_preserving=True, force_inline_lift_args=True
)
new_node = inline_lifts.InlineLifts().visit(new_node)

type_inference.copy_type(from_=expr, to=new_node, allow_untyped=True)

return new_node


@dataclasses.dataclass
class FuseAsFieldOp(eve.NodeTranslator):
"""
Expand Down Expand Up @@ -98,38 +199,6 @@ class FuseAsFieldOp(eve.NodeTranslator):

uids: eve_utils.UIDGenerator

def _inline_as_fieldop_arg(self, arg: itir.Expr) -> tuple[itir.Expr, dict[str, itir.Expr]]:
assert cpm.is_applied_as_fieldop(arg)
arg = _canonicalize_as_fieldop(arg)

stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop`
inner_args: list[itir.Expr] = arg.args
extracted_args: dict[str, itir.Expr] = {} # mapping from outer-stencil param to arg

stencil_params: list[itir.Sym] = []
stencil_body: itir.Expr = stencil.expr

for inner_param, inner_arg in zip(stencil.params, inner_args, strict=True):
if isinstance(inner_arg, itir.SymRef):
stencil_params.append(inner_param)
extracted_args[inner_arg.id] = inner_arg
elif isinstance(inner_arg, itir.Literal):
# note: only literals, not all scalar expressions are required as it doesn't make sense
# for them to be computed per grid point.
stencil_body = im.let(inner_param, im.promote_to_const_iterator(inner_arg))(
stencil_body
)
else:
# a scalar expression, a previously not inlined `as_fieldop` call or an opaque
# expression e.g. containing a tuple
stencil_params.append(inner_param)
new_outer_stencil_param = self.uids.sequential_id(prefix="__iasfop")
extracted_args[new_outer_stencil_param] = inner_arg

return im.lift(im.lambda_(*stencil_params)(stencil_body))(
*extracted_args.keys()
), extracted_args

@classmethod
def apply(
cls,
Expand Down Expand Up @@ -158,72 +227,26 @@ def visit_FunCall(self, node: itir.FunCall):

if cpm.is_call_to(node.fun, "as_fieldop") and isinstance(node.fun.args[0], itir.Lambda):
stencil: itir.Lambda = node.fun.args[0]
domain = node.fun.args[1] if len(node.fun.args) > 1 else None

shifts = trace_shifts.trace_stencil(stencil)

args: list[itir.Expr] = node.args
shifts = trace_shifts.trace_stencil(stencil)

new_args: dict[str, itir.Expr] = {}
new_stencil_body: itir.Expr = stencil.expr

for stencil_param, arg, arg_shifts in zip(stencil.params, args, shifts, strict=True):
eligible_args = []
for arg, arg_shifts in zip(args, shifts, strict=True):
assert isinstance(arg.type, ts.TypeSpec)
dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type)
# TODO(tehrengruber): make this configurable
should_inline = _is_tuple_expr_of_literals(arg) or (
isinstance(arg, itir.FunCall)
and (
cpm.is_call_to(arg.fun, "as_fieldop")
and isinstance(arg.fun.args[0], itir.Lambda)
or cpm.is_call_to(arg, "if_")
eligible_args.append(
_is_tuple_expr_of_literals(arg)
or (
isinstance(arg, itir.FunCall)
and (
cpm.is_call_to(arg.fun, "as_fieldop")
and isinstance(arg.fun.args[0], itir.Lambda)
or cpm.is_call_to(arg, "if_")
)
and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1)
)
and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1)
)
if should_inline:
if cpm.is_applied_as_fieldop(arg):
pass
elif cpm.is_call_to(arg, "if_"):
# TODO(tehrengruber): revisit if we want to inline if_
type_ = arg.type
arg = im.op_as_fieldop("if_")(*arg.args)
arg.type = type_
elif _is_tuple_expr_of_literals(arg):
arg = im.op_as_fieldop(im.lambda_()(arg))()
else:
raise NotImplementedError()

inline_expr, extracted_args = self._inline_as_fieldop_arg(arg)

new_stencil_body = im.let(stencil_param, inline_expr)(new_stencil_body)

new_args = _merge_arguments(new_args, extracted_args)
else:
assert not isinstance(dtype, it_ts.ListType)
new_param: str
if isinstance(
arg, itir.SymRef
): # use name from outer scope (optional, just to get a nice IR)
new_param = arg.id
new_stencil_body = im.let(stencil_param.id, arg.id)(new_stencil_body)
else:
new_param = stencil_param.id
new_args = _merge_arguments(new_args, {new_param: arg})

new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)(
*new_args.values()
)

# simplify stencil directly to keep the tree small
new_node = inline_center_deref_lift_vars.InlineCenterDerefLiftVars.apply(
new_node
) # to keep the tree small
new_node = inline_lambdas.InlineLambdas.apply(
new_node, opcount_preserving=True, force_inline_lift_args=True
)
new_node = inline_lifts.InlineLifts().visit(new_node)

type_inference.copy_type(from_=node, to=new_node)

return new_node
return fuse_as_fieldop(node, eligible_args, uids=self.uids)
return node
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _transform_by_pattern(
# or a tuple thereof)
# - one `SetAt` statement that materializes the expression into the temporary
for tmp_sym, tmp_expr in extracted_fields.items():
domain = tmp_expr.annex.domain
domain: infer_domain.DomainAccess = tmp_expr.annex.domain

# TODO(tehrengruber): Implement. This happens when the expression is a combination
# of an `if_` call with a tuple, e.g., `if_(cond, {a, b}, {c, d})`. As long as we are
Expand Down Expand Up @@ -186,7 +186,7 @@ def create_global_tmps(
This pass looks at all `as_fieldop` calls and transforms field-typed subexpressions of its
arguments into temporaries.
"""
program = infer_domain.infer_program(program, offset_provider)
program = infer_domain.infer_program(program, offset_provider=offset_provider)
program = type_inference.infer(
program, offset_provider_type=common.offset_provider_to_type(offset_provider)
)
Expand Down
Loading

0 comments on commit 59e0ed5

Please sign in to comment.