Skip to content

Commit

Permalink
feat[next]: Inline dynamic shifts (#1738)
Browse files Browse the repository at this point in the history
Dynamic shifts are not supported in the domain inference. In order to
make them work nonetheless this PR aggressively inlines all arguments to
`as_fieldop` until they contain only references to `itir.Program`
params. Additionally the domain inference is extended to tolerate such
`as_fieldop` by introducing a special domain marker that signifies a
domain is unknown.

---------

Co-authored-by: Hannes Vogt <[email protected]>
Co-authored-by: Edoardo Paone <[email protected]>
  • Loading branch information
3 people authored Dec 6, 2024
1 parent 54f176f commit ae62965
Show file tree
Hide file tree
Showing 8 changed files with 492 additions and 237 deletions.
209 changes: 116 additions & 93 deletions src/gt4py/next/iterator/transforms/fuse_as_fieldop.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall:
if cpm.is_ref_to(stencil, "deref"):
stencil = im.lambda_("arg")(im.deref("arg"))
new_expr = im.as_fieldop(stencil, domain)(*expr.args)
type_inference.copy_type(from_=expr, to=new_expr)
type_inference.copy_type(from_=expr, to=new_expr, allow_untyped=True)

return new_expr

Expand All @@ -68,6 +68,107 @@ def _is_tuple_expr_of_literals(expr: itir.Expr):
return isinstance(expr, itir.Literal)


def _inline_as_fieldop_arg(
arg: itir.Expr, *, uids: eve_utils.UIDGenerator
) -> tuple[itir.Expr, dict[str, itir.Expr]]:
assert cpm.is_applied_as_fieldop(arg)
arg = _canonicalize_as_fieldop(arg)

stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop`
inner_args: list[itir.Expr] = arg.args
extracted_args: dict[str, itir.Expr] = {} # mapping from outer-stencil param to arg

stencil_params: list[itir.Sym] = []
stencil_body: itir.Expr = stencil.expr

for inner_param, inner_arg in zip(stencil.params, inner_args, strict=True):
if isinstance(inner_arg, itir.SymRef):
stencil_params.append(inner_param)
extracted_args[inner_arg.id] = inner_arg
elif isinstance(inner_arg, itir.Literal):
# note: only literals, not all scalar expressions are required as it doesn't make sense
# for them to be computed per grid point.
stencil_body = im.let(inner_param, im.promote_to_const_iterator(inner_arg))(
stencil_body
)
else:
# a scalar expression, a previously not inlined `as_fieldop` call or an opaque
# expression e.g. containing a tuple
stencil_params.append(inner_param)
new_outer_stencil_param = uids.sequential_id(prefix="__iasfop")
extracted_args[new_outer_stencil_param] = inner_arg

return im.lift(im.lambda_(*stencil_params)(stencil_body))(
*extracted_args.keys()
), extracted_args


def fuse_as_fieldop(
expr: itir.Expr, eligible_args: list[bool], *, uids: eve_utils.UIDGenerator
) -> itir.Expr:
assert cpm.is_applied_as_fieldop(expr) and isinstance(expr.fun.args[0], itir.Lambda) # type: ignore[attr-defined] # ensured by is_applied_as_fieldop

stencil: itir.Lambda = expr.fun.args[0] # type: ignore[attr-defined] # ensured by is_applied_as_fieldop
domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined] # ensured by is_applied_as_fieldop

args: list[itir.Expr] = expr.args

new_args: dict[str, itir.Expr] = {}
new_stencil_body: itir.Expr = stencil.expr

for eligible, stencil_param, arg in zip(eligible_args, stencil.params, args, strict=True):
if eligible:
if cpm.is_applied_as_fieldop(arg):
pass
elif cpm.is_call_to(arg, "if_"):
# TODO(tehrengruber): revisit if we want to inline if_
type_ = arg.type
arg = im.op_as_fieldop("if_")(*arg.args)
arg.type = type_
elif _is_tuple_expr_of_literals(arg):
arg = im.op_as_fieldop(im.lambda_()(arg))()
else:
raise NotImplementedError()

inline_expr, extracted_args = _inline_as_fieldop_arg(arg, uids=uids)

new_stencil_body = im.let(stencil_param, inline_expr)(new_stencil_body)

new_args = _merge_arguments(new_args, extracted_args)
else:
# just a safety check if typing information is available
if arg.type and not isinstance(arg.type, ts.DeferredType):
assert isinstance(arg.type, ts.TypeSpec)
dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type)
assert not isinstance(dtype, it_ts.ListType)
new_param: str
if isinstance(
arg, itir.SymRef
): # use name from outer scope (optional, just to get a nice IR)
new_param = arg.id
new_stencil_body = im.let(stencil_param.id, arg.id)(new_stencil_body)
else:
new_param = stencil_param.id
new_args = _merge_arguments(new_args, {new_param: arg})

new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)(
*new_args.values()
)

# simplify stencil directly to keep the tree small
new_node = inline_center_deref_lift_vars.InlineCenterDerefLiftVars.apply(
new_node
) # to keep the tree small
new_node = inline_lambdas.InlineLambdas.apply(
new_node, opcount_preserving=True, force_inline_lift_args=True
)
new_node = inline_lifts.InlineLifts().visit(new_node)

type_inference.copy_type(from_=expr, to=new_node, allow_untyped=True)

return new_node


@dataclasses.dataclass
class FuseAsFieldOp(eve.NodeTranslator):
"""
Expand Down Expand Up @@ -98,38 +199,6 @@ class FuseAsFieldOp(eve.NodeTranslator):

uids: eve_utils.UIDGenerator

def _inline_as_fieldop_arg(self, arg: itir.Expr) -> tuple[itir.Expr, dict[str, itir.Expr]]:
assert cpm.is_applied_as_fieldop(arg)
arg = _canonicalize_as_fieldop(arg)

stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop`
inner_args: list[itir.Expr] = arg.args
extracted_args: dict[str, itir.Expr] = {} # mapping from outer-stencil param to arg

stencil_params: list[itir.Sym] = []
stencil_body: itir.Expr = stencil.expr

for inner_param, inner_arg in zip(stencil.params, inner_args, strict=True):
if isinstance(inner_arg, itir.SymRef):
stencil_params.append(inner_param)
extracted_args[inner_arg.id] = inner_arg
elif isinstance(inner_arg, itir.Literal):
# note: only literals, not all scalar expressions are required as it doesn't make sense
# for them to be computed per grid point.
stencil_body = im.let(inner_param, im.promote_to_const_iterator(inner_arg))(
stencil_body
)
else:
# a scalar expression, a previously not inlined `as_fieldop` call or an opaque
# expression e.g. containing a tuple
stencil_params.append(inner_param)
new_outer_stencil_param = self.uids.sequential_id(prefix="__iasfop")
extracted_args[new_outer_stencil_param] = inner_arg

return im.lift(im.lambda_(*stencil_params)(stencil_body))(
*extracted_args.keys()
), extracted_args

@classmethod
def apply(
cls,
Expand Down Expand Up @@ -158,72 +227,26 @@ def visit_FunCall(self, node: itir.FunCall):

if cpm.is_call_to(node.fun, "as_fieldop") and isinstance(node.fun.args[0], itir.Lambda):
stencil: itir.Lambda = node.fun.args[0]
domain = node.fun.args[1] if len(node.fun.args) > 1 else None

shifts = trace_shifts.trace_stencil(stencil)

args: list[itir.Expr] = node.args
shifts = trace_shifts.trace_stencil(stencil)

new_args: dict[str, itir.Expr] = {}
new_stencil_body: itir.Expr = stencil.expr

for stencil_param, arg, arg_shifts in zip(stencil.params, args, shifts, strict=True):
eligible_args = []
for arg, arg_shifts in zip(args, shifts, strict=True):
assert isinstance(arg.type, ts.TypeSpec)
dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type)
# TODO(tehrengruber): make this configurable
should_inline = _is_tuple_expr_of_literals(arg) or (
isinstance(arg, itir.FunCall)
and (
cpm.is_call_to(arg.fun, "as_fieldop")
and isinstance(arg.fun.args[0], itir.Lambda)
or cpm.is_call_to(arg, "if_")
eligible_args.append(
_is_tuple_expr_of_literals(arg)
or (
isinstance(arg, itir.FunCall)
and (
cpm.is_call_to(arg.fun, "as_fieldop")
and isinstance(arg.fun.args[0], itir.Lambda)
or cpm.is_call_to(arg, "if_")
)
and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1)
)
and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1)
)
if should_inline:
if cpm.is_applied_as_fieldop(arg):
pass
elif cpm.is_call_to(arg, "if_"):
# TODO(tehrengruber): revisit if we want to inline if_
type_ = arg.type
arg = im.op_as_fieldop("if_")(*arg.args)
arg.type = type_
elif _is_tuple_expr_of_literals(arg):
arg = im.op_as_fieldop(im.lambda_()(arg))()
else:
raise NotImplementedError()

inline_expr, extracted_args = self._inline_as_fieldop_arg(arg)

new_stencil_body = im.let(stencil_param, inline_expr)(new_stencil_body)

new_args = _merge_arguments(new_args, extracted_args)
else:
assert not isinstance(dtype, it_ts.ListType)
new_param: str
if isinstance(
arg, itir.SymRef
): # use name from outer scope (optional, just to get a nice IR)
new_param = arg.id
new_stencil_body = im.let(stencil_param.id, arg.id)(new_stencil_body)
else:
new_param = stencil_param.id
new_args = _merge_arguments(new_args, {new_param: arg})

new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)(
*new_args.values()
)

# simplify stencil directly to keep the tree small
new_node = inline_center_deref_lift_vars.InlineCenterDerefLiftVars.apply(
new_node
) # to keep the tree small
new_node = inline_lambdas.InlineLambdas.apply(
new_node, opcount_preserving=True, force_inline_lift_args=True
)
new_node = inline_lifts.InlineLifts().visit(new_node)

type_inference.copy_type(from_=node, to=new_node)

return new_node
return fuse_as_fieldop(node, eligible_args, uids=self.uids)
return node
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _transform_by_pattern(
# or a tuple thereof)
# - one `SetAt` statement that materializes the expression into the temporary
for tmp_sym, tmp_expr in extracted_fields.items():
domain = tmp_expr.annex.domain
domain: infer_domain.DomainAccess = tmp_expr.annex.domain

# TODO(tehrengruber): Implement. This happens when the expression is a combination
# of an `if_` call with a tuple, e.g., `if_(cond, {a, b}, {c, d})`. As long as we are
Expand Down Expand Up @@ -186,7 +186,7 @@ def create_global_tmps(
This pass looks at all `as_fieldop` calls and transforms field-typed subexpressions of its
arguments into temporaries.
"""
program = infer_domain.infer_program(program, offset_provider)
program = infer_domain.infer_program(program, offset_provider=offset_provider)
program = type_inference.infer(
program, offset_provider_type=common.offset_provider_to_type(offset_provider)
)
Expand Down
Loading

0 comments on commit ae62965

Please sign in to comment.