Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature[next]: Temporary extraction heuristics #1341

Merged
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 89 additions & 31 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import dataclasses
import functools
from collections.abc import Mapping
from typing import Any, Final, Iterable, Literal, Optional, Sequence
from typing import Any, Callable, Final, Iterable, Literal, Optional, Sequence

import gt4py.eve as eve
import gt4py.next as gtx
Expand Down Expand Up @@ -148,20 +148,49 @@ def canonicalize_applied_lift(closure_params: list[str], node: ir.FunCall) -> ir
return node


def temporary_extraction_predicate(expr: ir.Node, num_occurences: int) -> bool:
"""Determine if `expr` is an applied lift that should be extracted as a temporary."""
if not is_applied_lift(expr):
return False
# do not extract when the result is a list as we can not create temporaries for
# these stencils
if isinstance(expr.annex.type.dtype, type_inference.List):
return False
stencil = expr.fun.args[0] # type: ignore[attr-defined] # ensured by `is_applied_lift`
used_symbols = collect_symbol_refs(stencil)
# do not extract when the stencil is capturing
if used_symbols:
@dataclasses.dataclass(frozen=True)
class TemporaryExtractionPredicate:
"""
Construct a callable that determines if a lift expr can and should be extracted to a temporary.

The class optionally takes a heuristic that can restrict the extraction.
"""

heuristics: Optional[Callable[[ir.Expr], bool]] = None

def __call__(self, expr: ir.Expr, num_occurences: int) -> bool:
"""Determine if `expr` is an applied lift that should be extracted as a temporary."""
if not is_applied_lift(expr):
return False
# do not extract when the result is a list (i.e. a lift expression used in a `reduce` call)
# as we can not create temporaries for these stencils
if isinstance(expr.annex.type.dtype, type_inference.List):
return False
if self.heuristics and not self.heuristics(expr):
return False
stencil = expr.fun.args[0] # type: ignore[attr-defined] # ensured by `is_applied_lift`
# do not extract when the stencil is capturing
used_symbols = collect_symbol_refs(stencil)
if used_symbols:
return False
return True


@dataclasses.dataclass(frozen=True)
class SimpleTemporaryExtractionHeuristics:
"""Heuristic that extracts only if a lift expr is derefed in more than one position."""

closure: ir.StencilClosure

@functools.cached_property
def closure_shifts(self):
return trace_shifts.TraceShifts.apply(self.closure, inputs_only=False)

def __call__(self, expr: ir.Expr) -> bool:
shifts = self.closure_shifts[id(expr)]
if len(shifts) > 1:
return True
return False
return True


def _closure_parameter_argument_mapping(closure: ir.StencilClosure):
Expand Down Expand Up @@ -191,7 +220,14 @@ def _ensure_expr_does_not_capture(expr: ir.Expr, whitelist: list[ir.Sym]) -> Non
assert not (set(used_symbol_refs) - {param.id for param in whitelist})


def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemporaries:
def split_closures(
node: ir.FencilDefinition,
offset_provider,
*,
extraction_heuristics: Optional[
Callable[[ir.StencilClosure], Callable[[ir.Expr], bool]]
] = None,
) -> FencilWithTemporaries:
"""Split closures on lifted function calls and introduce new temporary buffers for return values.

Newly introduced temporaries will have the symbolic size of `AUTO_DOMAIN`. A symbol with the
Expand All @@ -203,6 +239,13 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp
3. Extract lifted function class as new closures with the previously created temporary as output.
The closures are processed in reverse order to properly respect the dependencies.
"""
if not extraction_heuristics:
# extract all (eligible) lifts
def always_extract_heuristics(_):
return lambda _: True

extraction_heuristics = always_extract_heuristics
Comment on lines +249 to +254
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we put this as default argument?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That wouldn't be very handy. The heuristics is passed through from the backend configuration trough the pass manager, temporary extraction pass until it ends up here. By using None we can easily just specify None in the backend configuration and it gets passed through until it is here translated into the default heuristics.


uid_gen_tmps = UIDGenerator(prefix="_tmp")

type_inference.infer_all(node, offset_provider=offset_provider, save_to_annex=True)
Expand All @@ -226,9 +269,13 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp
current_closure.stencil if not is_scan else current_closure.stencil.args[0] # type: ignore[attr-defined] # ensured by is_scan
)

extraction_predicate = TemporaryExtractionPredicate(
extraction_heuristics(current_closure)
)

stencil_body, extracted_lifts, _ = extract_subexpression(
current_closure_stencil.expr,
temporary_extraction_predicate,
extraction_predicate,
uid_gen_tmps,
once_only=True,
deepest_expr_first=True,
Expand Down Expand Up @@ -445,7 +492,12 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An
if closure.domain == AUTO_DOMAIN:
# every closure with auto domain should have a single out field
assert isinstance(closure.output, ir.SymRef)

if closure.output.id not in domains:
raise NotImplementedError(f"Closure output {closure.output.id} is never used.")
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved

domain = domains[closure.output.id]

closure = ir.StencilClosure(
domain=copy.deepcopy(domain),
stencil=closure.stencil,
Expand All @@ -457,14 +509,6 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An

closures.append(closure)

if closure.stencil == ir.SymRef(id="deref"):
# all closure inputs inherit the domain
for input_arg in _tuple_constituents(closure.inputs[0]):
assert isinstance(input_arg, ir.SymRef)
assert domains.get(input_arg.id, domain) == domain
domains[input_arg.id] = domain
continue

local_shifts = trace_shifts.TraceShifts.apply(closure)
for param, shift_chains in local_shifts.items():
assert isinstance(param, str)
Expand All @@ -483,19 +527,25 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An
nbt_provider = offset_provider[offset_name]
old_axis = nbt_provider.origin_axis.value
new_axis = nbt_provider.neighbor_axis.value
consumed_domain.ranges.pop(old_axis)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated to this PR. This approach of popping from the dict failed when the domain had more than one axis as the order was not preserved.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does the order matter?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that I think about it, I got to admit I don't know. With the embedded backend it shouldn't be a problem, but with gtfn it did not work, but just silently resulted in all values being zero. Not sure what's the best way to proceed, shall I just create an issue to investigate?

assert new_axis not in consumed_domain.ranges
consumed_domain.ranges[new_axis] = SymbolicRange(
new_range = SymbolicRange(
im.literal("0", ir.INTEGER_INDEX_BUILTIN),
im.literal(str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN),
)
consumed_domain.ranges = dict(
(axis, range_) if axis != old_axis else (new_axis, new_range)
for axis, range_ in consumed_domain.ranges.items()
)
else:
raise NotImplementedError
raise NotImplementedError()
consumed_domains.append(consumed_domain)

# compute the bounds of all consumed domains
if consumed_domains:
domains[param] = domain_union(consumed_domains).as_expr()
if all(
consumed_domain.ranges.keys() == consumed_domains[0].ranges.keys()
for consumed_domain in consumed_domains
): # scalar otherwise
domains[param] = domain_union(consumed_domains).as_expr()

return FencilWithTemporaries(
fencil=ir.FencilDefinition(
Expand Down Expand Up @@ -569,10 +619,18 @@ class CreateGlobalTmps(NodeTranslator):
"""

def visit_FencilDefinition(
self, node: ir.FencilDefinition, *, offset_provider: Mapping[str, Any]
self,
node: ir.FencilDefinition,
*,
offset_provider: Mapping[str, Any],
extraction_heuristics: Optional[
Callable[[ir.StencilClosure], Callable[[ir.Expr], bool]]
] = None,
) -> FencilWithTemporaries:
# Split closures on lifted function calls and introduce temporaries
res = split_closures(node, offset_provider=offset_provider)
res = split_closures(
node, offset_provider=offset_provider, extraction_heuristics=extraction_heuristics
)
# Prune unreferences closure inputs introduced in the previous step
res = PruneClosureInputs().visit(res)
# Prune unused temporaries possibly introduced in the previous step
Expand Down
46 changes: 34 additions & 12 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later

import enum
from typing import Callable, Optional

from gt4py.next.iterator import ir
from gt4py.next.iterator.transforms import simple_inline_heuristic
Expand Down Expand Up @@ -50,8 +51,6 @@ def _inline_lifts(ir, lift_mode):
return InlineLifts(
flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT
| InlineLifts.Flag.INLINE_DEREF_LIFT # some tuple exprs found in FVM don't work yet.
| InlineLifts.Flag.INLINE_LIFTED_ARGS
# needed for UnrollReduce and lift args like `(↑(λ() → constant)`
).visit(ir)
else:
raise ValueError()
Expand All @@ -72,6 +71,8 @@ def _inline_into_scan(ir, *, max_iter=10):
return ir


# TODO(tehrengruber): Revisit interface to configure temporary extraction. We currently forward
# `lift_mode` and `temporary_extraction_heuristics` which is inconvenient.
def apply_common_transforms(
ir: ir.Node,
*,
Expand All @@ -80,6 +81,9 @@ def apply_common_transforms(
unroll_reduce=False,
common_subexpression_elimination=True,
unconditionally_collapse_tuples=False,
temporary_extraction_heuristics: Optional[
Callable[[ir.StencilClosure], Callable[[ir.Expr], bool]]
] = None,
):
if lift_mode is None:
lift_mode = LiftMode.FORCE_INLINE
Expand Down Expand Up @@ -118,6 +122,32 @@ def apply_common_transforms(
else:
raise RuntimeError("Inlining lift and lambdas did not converge.")

if lift_mode != LiftMode.FORCE_INLINE:
assert offset_provider is not None
ir = CreateGlobalTmps().visit(
ir,
offset_provider=offset_provider,
extraction_heuristics=temporary_extraction_heuristics,
)

for _ in range(10):
inlined = InlineLifts().visit(ir)
inlined = InlineLambdas.apply(
inlined,
opcount_preserving=True,
force_inline_lift_args=True,
)
if inlined == ir:
break
ir = inlined
else:
raise RuntimeError("Inlining lift and lambdas did not converge.")
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved

# If after creating temporaries, the scan is not at the top, we inline.
# The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it.
# λ(inp) → scan(λ(state, k, kp) → state + ·k + ·kp, True, 0.0)(inp, ⟪Koffₒ, 1ₒ⟫(inp))`
ir = _inline_into_scan(ir)

# Since `CollapseTuple` relies on the type inference which does not support returning tuples
# larger than the number of closure outputs as given by the unconditional collapse, we can
# only run the unconditional version here instead of in the loop above.
Expand All @@ -131,6 +161,7 @@ def apply_common_transforms(

ir = FuseMaps().visit(ir)
ir = CollapseListGet().visit(ir)

if unroll_reduce:
for _ in range(10):
unrolled = UnrollReduce.apply(ir, offset_provider=offset_provider)
Expand All @@ -139,20 +170,11 @@ def apply_common_transforms(
ir = unrolled
ir = CollapseListGet().visit(ir)
ir = NormalizeShifts().visit(ir)
ir = _inline_lifts(ir, lift_mode)
ir = _inline_lifts(ir, LiftMode.FORCE_INLINE)
ir = NormalizeShifts().visit(ir)
else:
raise RuntimeError("Reduction unrolling failed.")

if lift_mode != LiftMode.FORCE_INLINE:
assert offset_provider is not None
ir = CreateGlobalTmps().visit(ir, offset_provider=offset_provider)
ir = InlineLifts().visit(ir)
# If after creating temporaries, the scan is not at the top, we inline.
# The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it.
# λ(inp) → scan(λ(state, k, kp) → state + ·k + ·kp, True, 0.0)(inp, ⟪Koffₒ, 1ₒ⟫(inp))`
ir = _inline_into_scan(ir)

ir = EtaReduction().visit(ir)
ir = ScanEtaReduction().visit(ir)

Expand Down
4 changes: 4 additions & 0 deletions src/gt4py/next/iterator/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import gt4py.next as gtx
from gt4py.next.common import Connectivity
from gt4py.next.iterator import ir
from gt4py.next.iterator.transforms.global_tmps import FencilWithTemporaries
from gt4py.next.type_inference import Type, TypeVar, freshen, reindex_vars, unify


Expand Down Expand Up @@ -931,6 +932,9 @@ def visit_StencilClosure(
)
return Closure(output=output, inputs=Tuple.from_elems(*inputs))

def visit_FencilWithTemporaries(self, node: FencilWithTemporaries, **kwargs):
return self.visit(node.fencil, **kwargs)

Comment on lines +940 to +942
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the pass does not return nodes again, but their type. As such the generic visit would try to create a FencilWithTemporaries where fencil has then been transformed into a type and then fail with e.g.

TypeError: 'FencilWithTemporaries.fencil' must be <class 'gt4py.next.iterator.ir.FencilDefinition'> (got 'FencilDefinitionType(name='__field_operator_testee', fundefs=EmptyTuple(), params=Tuple(front=TypeVar(idx=1305), others=Tuple(front=TypeVar(idx=1314), others=Tuple(front=TypeVar(idx=1323), others=Tuple(front=TypeVar(idx=1324), others=Tuple(front=TypeVar(idx=1325), others=Tuple(front=TypeVar(idx=1326), others=Tuple(front=TypeVar(idx=1327), others=Tuple(front=TypeVar(idx=1328), others=EmptyTuple())))))))))' which is a <class 'gt4py.next.iterator.type_inference.FencilDefinitionType'>).

def visit_FencilDefinition(
self,
node: ir.FencilDefinition,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def _lower(
offset_provider=offset_provider,
unroll_reduce=do_unroll,
unconditionally_collapse_tuples=True, # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements
temporary_extraction_heuristics=kwargs.get("temporary_extraction_heuristics"),
)
gtfn_ir = GTFN_lowering.apply(
program,
Expand Down
10 changes: 7 additions & 3 deletions src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import dataclasses
import warnings
from typing import Any, Final, Optional
from typing import Any, Callable, Final, Optional

import numpy as np

Expand Down Expand Up @@ -54,6 +54,9 @@ class GTFNTranslationStep(
use_imperative_backend: bool = False
lift_mode: Optional[LiftMode] = None
device_type: core_defs.DeviceType = core_defs.DeviceType.CPU
temporary_extraction_heuristics: Optional[
Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]]
] = None

def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSettings:
match self.device_type:
Expand Down Expand Up @@ -196,10 +199,10 @@ def __call__(
# here and warn the user if it differs from the one configured.
runtime_lift_mode = inp.kwargs.pop("lift_mode", None)
lift_mode = runtime_lift_mode or self.lift_mode
if runtime_lift_mode != self.lift_mode:
if runtime_lift_mode and runtime_lift_mode != self.lift_mode:
warnings.warn(
f"GTFN Backend was configured for LiftMode `{str(self.lift_mode)}`, but "
"overriden to be {str(runtime_lift_mode)} at runtime."
f"overriden to be {str(runtime_lift_mode)} at runtime."
)

# combine into a format that is aligned with what the backend expects
Expand All @@ -218,6 +221,7 @@ def __call__(
enable_itir_transforms=self.enable_itir_transforms,
lift_mode=lift_mode,
imperative=self.use_imperative_backend,
temporary_extraction_heuristics=self.temporary_extraction_heuristics,
**inp.kwargs,
)
source_code = interface.format_source(
Expand Down
5 changes: 3 additions & 2 deletions src/gt4py/next/program_processors/runners/gtfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import gt4py.next.allocators as next_allocators
from gt4py.eve.utils import content_hash
from gt4py.next import common
from gt4py.next.iterator.transforms import LiftMode
from gt4py.next.iterator.transforms import LiftMode, global_tmps
from gt4py.next.otf import languages, recipes, stages, step_types, workflow
from gt4py.next.otf.binding import nanobind
from gt4py.next.otf.compilation import cache, compiler
Expand Down Expand Up @@ -187,7 +187,8 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int:
name="run_gtfn_with_temporaries",
otf_workflow=gtfn_executor.otf_workflow.replace(
translation=gtfn_executor.otf_workflow.translation.replace(
lift_mode=LiftMode.FORCE_TEMPORARIES
lift_mode=LiftMode.FORCE_TEMPORARIES,
temporary_extraction_heuristics=global_tmps.SimpleTemporaryExtractionHeuristics,
),
),
),
Expand Down
Loading