diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py
index c423a3c277..4f4fd053b2 100644
--- a/src/gt4py/next/iterator/transforms/global_tmps.py
+++ b/src/gt4py/next/iterator/transforms/global_tmps.py
@@ -15,7 +15,7 @@
 import dataclasses
 import functools
 from collections.abc import Mapping
-from typing import Any, Final, Iterable, Literal, Optional, Sequence
+from typing import Any, Callable, Final, Iterable, Literal, Optional, Sequence
 
 import gt4py.eve as eve
 import gt4py.next as gtx
@@ -150,20 +150,54 @@ def canonicalize_applied_lift(closure_params: list[str], node: ir.FunCall) -> ir
     return node
 
 
-def temporary_extraction_predicate(expr: ir.Node, num_occurences: int) -> bool:
-    """Determine if `expr` is an applied lift that should be extracted as a temporary."""
-    if not is_applied_lift(expr):
-        return False
-    # do not extract when the result is a list as we can not create temporaries for
-    # these stencils
-    if isinstance(expr.annex.type.dtype, type_inference.List):
-        return False
-    stencil = expr.fun.args[0]  # type: ignore[attr-defined] # ensured by `is_applied_lift`
-    used_symbols = collect_symbol_refs(stencil)
-    # do not extract when the stencil is capturing
-    if used_symbols:
+@dataclasses.dataclass(frozen=True)
+class TemporaryExtractionPredicate:
+    """
+    Construct a callable that determines if a lift expr can and should be extracted to a temporary.
+
+    The class optionally takes a heuristic that can restrict the extraction.
+    """
+
+    heuristics: Optional[Callable[[ir.Expr], bool]] = None
+
+    def __call__(self, expr: ir.Expr, num_occurences: int) -> bool:
+        """Determine if `expr` is an applied lift that should be extracted as a temporary."""
+        if not is_applied_lift(expr):
+            return False
+        # do not extract when the result is a list (i.e. a lift expression used in a `reduce` call)
+        # as we can not create temporaries for these stencils
+        if isinstance(expr.annex.type.dtype, type_inference.List):
+            return False
+        if self.heuristics and not self.heuristics(expr):
+            return False
+        stencil = expr.fun.args[0]  # type: ignore[attr-defined] # ensured by `is_applied_lift`
+        # do not extract when the stencil is capturing
+        used_symbols = collect_symbol_refs(stencil)
+        if used_symbols:
+            return False
+        return True
+
+
+@dataclasses.dataclass(frozen=True)
+class SimpleTemporaryExtractionHeuristics:
+    """
+    Heuristic that extracts only if a lift expr is derefed in more than one position.
+
+    Note that such expressions result in redundant computations if inlined instead of being
+    placed into a temporary.
+ """ + + closure: ir.StencilClosure + + @functools.cached_property + def closure_shifts(self): + return trace_shifts.TraceShifts.apply(self.closure, inputs_only=False) + + def __call__(self, expr: ir.Expr) -> bool: + shifts = self.closure_shifts[id(expr)] + if len(shifts) > 1: + return True return False - return True def _closure_parameter_argument_mapping(closure: ir.StencilClosure): @@ -193,7 +227,14 @@ def _ensure_expr_does_not_capture(expr: ir.Expr, whitelist: list[ir.Sym]) -> Non assert not (set(used_symbol_refs) - {param.id for param in whitelist}) -def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemporaries: +def split_closures( + node: ir.FencilDefinition, + offset_provider, + *, + extraction_heuristics: Optional[ + Callable[[ir.StencilClosure], Callable[[ir.Expr], bool]] + ] = None, +) -> FencilWithTemporaries: """Split closures on lifted function calls and introduce new temporary buffers for return values. Newly introduced temporaries will have the symbolic size of `AUTO_DOMAIN`. A symbol with the @@ -205,6 +246,13 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp 3. Extract lifted function class as new closures with the previously created temporary as output. The closures are processed in reverse order to properly respect the dependencies. """ + if not extraction_heuristics: + # extract all (eligible) lifts + def always_extract_heuristics(_): + return lambda _: True + + extraction_heuristics = always_extract_heuristics + uid_gen_tmps = UIDGenerator(prefix="_tmp") type_inference.infer_all(node, offset_provider=offset_provider, save_to_annex=True) @@ -228,9 +276,13 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp current_closure.stencil if not is_scan else current_closure.stencil.args[0] # type: ignore[attr-defined] # ensured by is_scan ) + extraction_predicate = TemporaryExtractionPredicate( + extraction_heuristics(current_closure) + ) + stencil_body, extracted_lifts, _ = extract_subexpression( current_closure_stencil.expr, - temporary_extraction_predicate, + extraction_predicate, uid_gen_tmps, once_only=True, deepest_expr_first=True, @@ -454,7 +506,12 @@ def update_domains( if closure.domain == AUTO_DOMAIN: # every closure with auto domain should have a single out field assert isinstance(closure.output, ir.SymRef) + + if closure.output.id not in domains: + raise NotImplementedError(f"Closure output '{closure.output.id}' is never used.") + domain = domains[closure.output.id] + closure = ir.StencilClosure( domain=copy.deepcopy(domain), stencil=closure.stencil, @@ -467,14 +524,6 @@ def update_domains( closures.append(closure) - if closure.stencil == ir.SymRef(id="deref"): - # all closure inputs inherit the domain - for input_arg in _tuple_constituents(closure.inputs[0]): - assert isinstance(input_arg, ir.SymRef) - assert domains.get(input_arg.id, domain) == domain - domains[input_arg.id] = domain - continue - local_shifts = trace_shifts.TraceShifts.apply(closure) for param, shift_chains in local_shifts.items(): assert isinstance(param, str) @@ -512,13 +561,22 @@ def update_domains( (axis, range_) if axis != old_axis else (new_axis, new_range) for axis, range_ in consumed_domain.ranges.items() ) + # TODO(tehrengruber): Revisit. Somehow the order matters so preserve it. 
+                        consumed_domain.ranges = dict(
+                            (axis, range_) if axis != old_axis else (new_axis, new_range)
+                            for axis, range_ in consumed_domain.ranges.items()
+                        )
                     else:
-                        raise NotImplementedError
+                        raise NotImplementedError()
                 consumed_domains.append(consumed_domain)
 
             # compute the bounds of all consumed domains
             if consumed_domains:
-                domains[param] = domain_union(consumed_domains).as_expr()
+                if all(
+                    consumed_domain.ranges.keys() == consumed_domains[0].ranges.keys()
+                    for consumed_domain in consumed_domains
+                ):  # scalar otherwise
+                    domains[param] = domain_union(consumed_domains).as_expr()
 
     return FencilWithTemporaries(
         fencil=ir.FencilDefinition(
@@ -597,10 +655,15 @@ def visit_FencilDefinition(
         node: ir.FencilDefinition,
         *,
         offset_provider: Mapping[str, Any],
+        extraction_heuristics: Optional[
+            Callable[[ir.StencilClosure], Callable[[ir.Expr], bool]]
+        ] = None,
         symbolic_sizes: Optional[dict[str, str]],
     ) -> FencilWithTemporaries:
         # Split closures on lifted function calls and introduce temporaries
-        res = split_closures(node, offset_provider=offset_provider)
+        res = split_closures(
+            node, offset_provider=offset_provider, extraction_heuristics=extraction_heuristics
+        )
         # Prune unreferences closure inputs introduced in the previous step
         res = PruneClosureInputs().visit(res)
         # Prune unused temporaries possibly introduced in the previous step
diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py
index 08897861c2..fe14a8f580 100644
--- a/src/gt4py/next/iterator/transforms/pass_manager.py
+++ b/src/gt4py/next/iterator/transforms/pass_manager.py
@@ -13,7 +13,7 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import enum
-from typing import Optional
+from typing import Callable, Optional
 
 from gt4py.next.iterator import ir
 from gt4py.next.iterator.transforms import simple_inline_heuristic
@@ -51,8 +51,6 @@ def _inline_lifts(ir, lift_mode):
         return InlineLifts(
             flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT
             | InlineLifts.Flag.INLINE_DEREF_LIFT  # some tuple exprs found in FVM don't work yet.
-            | InlineLifts.Flag.INLINE_LIFTED_ARGS
-            # needed for UnrollReduce and lift args like `(↑(λ() → constant)`
         ).visit(ir)
     else:
         raise ValueError()
@@ -73,6 +71,8 @@ def _inline_into_scan(ir, *, max_iter=10):
     return ir
 
 
+# TODO(tehrengruber): Revisit interface to configure temporary extraction. We currently forward
+# `lift_mode` and `temporary_extraction_heuristics` which is inconvenient.
 def apply_common_transforms(
     ir: ir.Node,
     *,
@@ -82,6 +82,9 @@ def apply_common_transforms(
     common_subexpression_elimination=True,
     force_inline_lambda_args=False,
     unconditionally_collapse_tuples=False,
+    temporary_extraction_heuristics: Optional[
+        Callable[[ir.StencilClosure], Callable[[ir.Expr], bool]]
+    ] = None,
     symbolic_domain_sizes: Optional[dict[str, str]] = None,
 ):
     if lift_mode is None:
@@ -121,6 +124,33 @@ def apply_common_transforms(
         else:
             raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.")
 
+    if lift_mode != LiftMode.FORCE_INLINE:
+        assert offset_provider is not None
+        ir = CreateGlobalTmps().visit(
+            ir,
+            offset_provider=offset_provider,
+            extraction_heuristics=temporary_extraction_heuristics,
+            symbolic_sizes=symbolic_domain_sizes,
+        )
+
+        for _ in range(10):
+            inlined = InlineLifts().visit(ir)
+            inlined = InlineLambdas.apply(
+                inlined,
+                opcount_preserving=True,
+                force_inline_lift_args=True,
+            )
+            if inlined == ir:
+                break
+            ir = inlined
+        else:
+            raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.")
+
+        # If after creating temporaries, the scan is not at the top, we inline.
+        # The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it.
+        # λ(inp) → scan(λ(state, k, kp) → state + ·k + ·kp, True, 0.0)(inp, ⟪Koffₒ, 1ₒ⟫(inp))`
+        ir = _inline_into_scan(ir)
+
     # Since `CollapseTuple` relies on the type inference which does not support returning tuples
     # larger than the number of closure outputs as given by the unconditional collapse, we can
     # only run the unconditional version here instead of in the loop above.
@@ -134,6 +164,7 @@ def apply_common_transforms(
 
     ir = FuseMaps().visit(ir)
     ir = CollapseListGet().visit(ir)
+
     if unroll_reduce:
         for _ in range(10):
             unrolled = UnrollReduce.apply(ir, offset_provider=offset_provider)
@@ -142,22 +173,11 @@ def apply_common_transforms(
             ir = unrolled
             ir = CollapseListGet().visit(ir)
             ir = NormalizeShifts().visit(ir)
-            ir = _inline_lifts(ir, lift_mode)
+            ir = _inline_lifts(ir, LiftMode.FORCE_INLINE)
             ir = NormalizeShifts().visit(ir)
         else:
             raise RuntimeError("Reduction unrolling failed.")
 
-    if lift_mode != LiftMode.FORCE_INLINE:
-        assert offset_provider is not None
-        ir = CreateGlobalTmps().visit(
-            ir, offset_provider=offset_provider, symbolic_sizes=symbolic_domain_sizes
-        )
-        ir = InlineLifts().visit(ir)
-        # If after creating temporaries, the scan is not at the top, we inline.
-        # The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it.
-        # λ(inp) → scan(λ(state, k, kp) → state + ·k + ·kp, True, 0.0)(inp, ⟪Koffₒ, 1ₒ⟫(inp))`
-        ir = _inline_into_scan(ir)
-
     ir = EtaReduction().visit(ir)
     ir = ScanEtaReduction().visit(ir)
 
diff --git a/src/gt4py/next/iterator/type_inference.py b/src/gt4py/next/iterator/type_inference.py
index d65f67b266..683a57561c 100644
--- a/src/gt4py/next/iterator/type_inference.py
+++ b/src/gt4py/next/iterator/type_inference.py
@@ -21,6 +21,7 @@
 import gt4py.next as gtx
 from gt4py.next.common import Connectivity
 from gt4py.next.iterator import ir
+from gt4py.next.iterator.transforms.global_tmps import FencilWithTemporaries
 from gt4py.next.type_inference import Type, TypeVar, freshen, reindex_vars, unify
 
 
@@ -936,6 +937,9 @@ def visit_StencilClosure(
         )
         return Closure(output=output, inputs=Tuple.from_elems(*inputs))
 
+    def visit_FencilWithTemporaries(self, node: FencilWithTemporaries, **kwargs):
+        return self.visit(node.fencil, **kwargs)
+
     def visit_FencilDefinition(
         self,
         node: ir.FencilDefinition,
diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
index 718fef72af..c157cdcc46 100644
--- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
+++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
@@ -17,7 +17,7 @@
 import dataclasses
 import functools
 import warnings
-from typing import Any, Final, Optional
+from typing import Any, Callable, Final, Optional
 
 import numpy as np
 
@@ -58,6 +58,9 @@ class GTFNTranslationStep(
     lift_mode: Optional[LiftMode] = None
     device_type: core_defs.DeviceType = core_defs.DeviceType.CPU
     symbolic_domain_sizes: Optional[dict[str, str]] = None
+    temporary_extraction_heuristics: Optional[
+        Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]]
+    ] = None
 
     def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSettings:
         match self.device_type:
@@ -179,14 +182,14 @@ def _preprocess_program(
         self,
         program: itir.FencilDefinition,
         offset_provider: dict[str, Connectivity | Dimension],
-        runtime_lift_mode: Optional[LiftMode] = None,
+        runtime_lift_mode: Optional[LiftMode],
     ) -> itir.FencilDefinition:
         # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added
         # to the interface of all (or at least all of concern) backends, but instead should be
         # configured in the backend itself (like it is here), until then we respect the argument
         # here and warn the user if it differs from the one configured.
         lift_mode = runtime_lift_mode or self.lift_mode
-        if lift_mode != self.lift_mode:
+        if runtime_lift_mode and runtime_lift_mode != self.lift_mode:
             warnings.warn(
                 f"GTFN Backend was configured for LiftMode `{str(self.lift_mode)}`, but "
                 f"overriden to be {str(runtime_lift_mode)} at runtime."
@@ -202,6 +205,7 @@ def _preprocess_program(
             # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements
             unconditionally_collapse_tuples=True,
             symbolic_domain_sizes=self.symbolic_domain_sizes,
+            temporary_extraction_heuristics=self.temporary_extraction_heuristics,
         )
 
         new_program = apply_common_transforms(
diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py
index baa45ddc0e..157c00c368 100644
--- a/src/gt4py/next/program_processors/runners/gtfn.py
+++ b/src/gt4py/next/program_processors/runners/gtfn.py
@@ -22,7 +22,7 @@
 import gt4py.next.allocators as next_allocators
 from gt4py.eve.utils import content_hash
 from gt4py.next import common
-from gt4py.next.iterator.transforms import LiftMode
+from gt4py.next.iterator.transforms import LiftMode, global_tmps
 from gt4py.next.otf import languages, recipes, stages, step_types, workflow
 from gt4py.next.otf.binding import nanobind
 from gt4py.next.otf.compilation import cache, compiler
@@ -187,7 +187,8 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int:
         name="run_gtfn_with_temporaries",
         otf_workflow=gtfn_executor.otf_workflow.replace(
             translation=gtfn_executor.otf_workflow.translation.replace(
-                lift_mode=LiftMode.FORCE_TEMPORARIES
+                lift_mode=LiftMode.FORCE_TEMPORARIES,
+                temporary_extraction_heuristics=global_tmps.SimpleTemporaryExtractionHeuristics,
             ),
         ),
     ),
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
index 7fc2d82e67..ae5e434085 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
@@ -622,7 +622,7 @@ def simple_scan_operator(carry: float) -> float:
 @pytest.mark.uses_lift_expressions
 @pytest.mark.uses_scan_nested
 def test_solve_triag(cartesian_case):
-    if cartesian_case.executor == gtfn.run_gtfn_with_temporaries:
+    if cartesian_case.executor == gtfn.run_gtfn_with_temporaries.executor:
         pytest.xfail("Temporary extraction does not work correctly in combination with scans.")
 
     @gtx.scan_operator(axis=KDim, forward=True, init=(0.0, 0.0))
diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py
index 5c2802f90c..46ca02217f 100644
--- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py
+++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py
@@ -20,6 +20,7 @@
 from gt4py.next.iterator.transforms.global_tmps import (
     AUTO_DOMAIN,
     FencilWithTemporaries,
+    SimpleTemporaryExtractionHeuristics,
     Temporary,
     collect_tmps_info,
     split_closures,
@@ -32,53 +33,23 @@ def test_split_closures():
     testee = ir.FencilDefinition(
         id="f",
         function_definitions=[],
-        params=[ir.Sym(id="d"), ir.Sym(id="inp"), ir.Sym(id="out")],
+        params=[im.sym("d"), im.sym("inp"), im.sym("out")],
         closures=[
             ir.StencilClosure(
-                domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]),
-                stencil=ir.Lambda(
-                    params=[ir.Sym(id="baz_inp")],
-                    expr=ir.FunCall(
-                        fun=ir.SymRef(id="deref"),
-                        args=[
-                            ir.FunCall(
-                                fun=ir.FunCall(
-                                    fun=ir.SymRef(id="lift"),
-                                    args=[
-                                        ir.Lambda(
-                                            params=[ir.Sym(id="bar_inp")],
-                                            expr=ir.FunCall(
-                                                fun=ir.SymRef(id="deref"),
-                                                args=[
-                                                    ir.FunCall(
-                                                        fun=ir.FunCall(
-                                                            fun=ir.SymRef(id="lift"),
-                                                            args=[
-                                                                ir.Lambda(
params=[ir.Sym(id="foo_inp")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.SymRef(id="foo_inp") - ], - ), - ) - ], - ), - args=[ir.SymRef(id="bar_inp")], - ) - ], - ), - ) - ], - ), - args=[ir.SymRef(id="baz_inp")], + domain=im.call("cartesian_domain")(), + stencil=im.lambda_("baz_inp")( + im.deref( + im.lift( + im.lambda_("bar_inp")( + im.deref( + im.lift(im.lambda_("foo_inp")(im.deref("foo_inp")))("bar_inp") + ) ) - ], - ), + )("baz_inp") + ) ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="inp")], + output=im.ref("out"), + inputs=[im.ref("inp")], ) ], ) @@ -87,54 +58,31 @@ def test_split_closures(): id="f", function_definitions=[], params=[ - ir.Sym(id="d"), - ir.Sym(id="inp"), - ir.Sym(id="out"), - ir.Sym(id="_tmp_1"), - ir.Sym(id="_tmp_2"), - ir.Sym(id="_gtmp_auto_domain"), + im.sym("d"), + im.sym("inp"), + im.sym("out"), + im.sym("_tmp_1"), + im.sym("_tmp_2"), + im.sym("_gtmp_auto_domain"), ], closures=[ ir.StencilClosure( domain=AUTO_DOMAIN, - stencil=ir.Lambda( - params=[ir.Sym(id="foo_inp")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ir.SymRef(id="foo_inp")], - ), - ), - output=ir.SymRef(id="_tmp_2"), - inputs=[ir.SymRef(id="inp")], + stencil=im.lambda_("foo_inp")(im.deref("foo_inp")), + output=im.ref("_tmp_2"), + inputs=[im.ref("inp")], ), ir.StencilClosure( domain=AUTO_DOMAIN, - stencil=ir.Lambda( - params=[ - ir.Sym(id="bar_inp"), - ir.Sym(id="_tmp_2"), - ], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.SymRef(id="_tmp_2"), - ], - ), - ), - output=ir.SymRef(id="_tmp_1"), - inputs=[ir.SymRef(id="inp"), ir.SymRef(id="_tmp_2")], + stencil=im.lambda_("bar_inp", "_tmp_2")(im.deref("_tmp_2")), + output=im.ref("_tmp_1"), + inputs=[im.ref("inp"), im.ref("_tmp_2")], ), ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="baz_inp"), ir.Sym(id="_tmp_1")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ir.SymRef(id="_tmp_1")], - ), - ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="inp"), ir.SymRef(id="_tmp_1")], + domain=im.call("cartesian_domain")(), + stencil=im.lambda_("baz_inp", "_tmp_1")(im.deref("_tmp_1")), + output=im.ref("out"), + inputs=[im.ref("inp"), im.ref("_tmp_1")], ), ], ) @@ -143,6 +91,60 @@ def test_split_closures(): assert actual.fencil == expected +def test_split_closures_simple_heuristics(): + UIDs.reset_sequence() + testee = ir.FencilDefinition( + id="f", + function_definitions=[], + params=[im.sym("d"), im.sym("inp"), im.sym("out")], + closures=[ + ir.StencilClosure( + domain=im.call("cartesian_domain")(), + stencil=im.lambda_("foo")( + im.let("lifted_it", im.lift(im.lambda_("bar")(im.deref("bar")))("foo"))( + im.plus(im.deref("lifted_it"), im.deref(im.shift("I", 1)("lifted_it"))) + ) + ), + output=im.ref("out"), + inputs=[im.ref("inp")], + ) + ], + ) + + expected = ir.FencilDefinition( + id="f", + function_definitions=[], + params=[ + im.sym("d"), + im.sym("inp"), + im.sym("out"), + im.sym("_tmp_1"), + im.sym("_gtmp_auto_domain"), + ], + closures=[ + ir.StencilClosure( + domain=AUTO_DOMAIN, + stencil=im.lambda_("bar")(im.deref("bar")), + output=im.ref("_tmp_1"), + inputs=[im.ref("inp")], + ), + ir.StencilClosure( + domain=im.call("cartesian_domain")(), + stencil=im.lambda_("foo", "_tmp_1")( + im.plus(im.deref("_tmp_1"), im.deref(im.shift("I", 1)("_tmp_1"))) + ), + output=im.ref("out"), + inputs=[im.ref("inp"), im.ref("_tmp_1")], + ), + ], + ) + actual = split_closures( + testee, 
+        extraction_heuristics=SimpleTemporaryExtractionHeuristics, offset_provider={}
+    )
+    assert actual.tmps == [Temporary(id="_tmp_1")]
+    assert actual.fencil == expected
+
+
 def test_split_closures_lifted_scan():
     UIDs.reset_sequence()
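
Usage note (illustrative, not part of the patch): the new hook expects a callable that maps an ir.StencilClosure to a per-expression predicate; when that predicate returns False for an applied lift, TemporaryExtractionPredicate skips extracting it, while its other checks still apply. Below is a minimal sketch of a custom heuristic wired through apply_common_transforms, assuming only the interfaces shown in this diff; the class name ExtractNothingHeuristics and the fencil/offset_provider placeholders are invented for illustration.

    import dataclasses

    from gt4py.next.iterator import ir as itir
    from gt4py.next.iterator.transforms import LiftMode
    from gt4py.next.iterator.transforms.pass_manager import apply_common_transforms


    @dataclasses.dataclass(frozen=True)
    class ExtractNothingHeuristics:
        # Hypothetical heuristic mirroring the shape of SimpleTemporaryExtractionHeuristics:
        # constructed once per closure, then called once per lift expression.
        closure: itir.StencilClosure

        def __call__(self, expr: itir.Expr) -> bool:
            # Always veto extraction, so no lift expression becomes a temporary.
            return False


    # transformed = apply_common_transforms(
    #     fencil,  # an itir.FencilDefinition (placeholder)
    #     lift_mode=LiftMode.FORCE_TEMPORARIES,
    #     offset_provider=offset_provider,  # placeholder
    #     temporary_extraction_heuristics=ExtractNothingHeuristics,
    # )

The same callable type can be set on the GTFN backend via GTFNTranslationStep.temporary_extraction_heuristics, which is how run_gtfn_with_temporaries plugs in SimpleTemporaryExtractionHeuristics in this diff.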