From 120a37e674fd67ef99dd94d09871f7a25a06fdc9 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 18 Sep 2023 11:52:36 +0200 Subject: [PATCH 1/8] Add temporary extraction heuristics --- .../next/iterator/transforms/global_tmps.py | 109 ++++++++++++++---- .../next/iterator/transforms/pass_manager.py | 46 ++++++-- src/gt4py/next/iterator/type_inference.py | 4 + .../codegens/gtfn/gtfn_backend.py | 1 + .../codegens/gtfn/gtfn_module.py | 10 +- .../program_processors/runners/gtfn_cpu.py | 7 +- 6 files changed, 136 insertions(+), 41 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index e1b697e0bc..5a0f97d32f 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -15,7 +15,7 @@ import dataclasses import functools from collections.abc import Mapping -from typing import Any, Final, Iterable, Literal, Optional, Sequence +from typing import Any, Callable, Final, Iterable, Literal, Optional, Sequence import gt4py.eve as eve import gt4py.next as gtx @@ -148,20 +148,49 @@ def canonicalize_applied_lift(closure_params: list[str], node: ir.FunCall) -> ir return node -def temporary_extraction_predicate(expr: ir.Node, num_occurences: int) -> bool: - """Determine if `expr` is an applied lift that should be extracted as a temporary.""" - if not is_applied_lift(expr): - return False - # do not extract when the result is a list as we can not create temporaries for - # these stencils - if isinstance(expr.annex.type.dtype, type_inference.List): - return False - stencil = expr.fun.args[0] # type: ignore[attr-defined] # ensured by `is_applied_lift` - used_symbols = collect_symbol_refs(stencil) - # do not extract when the stencil is capturing - if used_symbols: - return False - return True +@dataclasses.dataclass(frozen=True) +class TemporaryExtractionPredicate: + """ + Construct a callable that determines if a lift expr can and should be extracted to a temporary. + + The class optionally takes a heuristics that can restrict the extraction. + """ + + heuristics: Optional[Callable[[ir.Expr], bool]] = None + + def __call__(self, expr: ir.Expr, num_occurences: int) -> bool: + """Determine if `expr` is an applied lift that should be extracted as a temporary.""" + if not is_applied_lift(expr): + return False + # do not extract when the result is a list as we can not create temporaries for + # these stencils + if isinstance(expr.annex.type.dtype, type_inference.List): + return False + if self.heuristics and not self.heuristics(expr): + return False + stencil = expr.fun.args[0] # type: ignore[attr-defined] # ensured by `is_applied_lift` + used_symbols = collect_symbol_refs(stencil) + # do not extract when the stencil is capturing + if used_symbols: + return False + return True + + +@dataclasses.dataclass(frozen=True) +class SimpleTemporaryExtractionHeuristics: + """Heuristic that extracts only if a lift expr is derefed in (at most) one position.""" + + closure: ir.StencilClosure + + @functools.cached_property + def closure_shifts(self): + return trace_shifts.TraceShifts.apply(self.closure, inputs_only=False) + + def __call__(self, expr: ir.Expr) -> bool: + shifts = self.closure_shifts[id(expr)] + if len(shifts) <= 1: + return False + return True def _closure_parameter_argument_mapping(closure: ir.StencilClosure): @@ -191,7 +220,14 @@ def _ensure_expr_does_not_capture(expr: ir.Expr, whitelist: list[ir.Sym]) -> Non assert not (set(used_symbol_refs) - {param.id for param in whitelist}) -def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemporaries: +def split_closures( + node: ir.FencilDefinition, + offset_provider, + *, + extraction_heuristics: Optional[ + Callable[[ir.StencilClosure], Callable[[ir.Expr], bool]] + ] = None, +) -> FencilWithTemporaries: """Split closures on lifted function calls and introduce new temporary buffers for return values. Newly introduced temporaries will have the symbolic size of `AUTO_DOMAIN`. A symbol with the @@ -203,6 +239,13 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp 3. Extract lifted function class as new closures with the previously created temporary as output. The closures are processed in reverse order to properly respect the dependencies. """ + if not extraction_heuristics: + # extract all (eligible) lifts + def always_extract_heuristics(_): + return lambda _: True + + extraction_heuristics = always_extract_heuristics + uid_gen_tmps = UIDGenerator(prefix="_tmp") type_inference.infer_all(node, offset_provider=offset_provider, save_to_annex=True) @@ -226,9 +269,13 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp current_closure.stencil if not is_scan else current_closure.stencil.args[0] # type: ignore[attr-defined] # ensured by is_scan ) + extraction_predicate = TemporaryExtractionPredicate( + extraction_heuristics(current_closure) + ) + stencil_body, extracted_lifts, _ = extract_subexpression( current_closure_stencil.expr, - temporary_extraction_predicate, + extraction_predicate, uid_gen_tmps, once_only=True, deepest_expr_first=True, @@ -483,19 +530,25 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An nbt_provider = offset_provider[offset_name] old_axis = nbt_provider.origin_axis.value new_axis = nbt_provider.neighbor_axis.value - consumed_domain.ranges.pop(old_axis) - assert new_axis not in consumed_domain.ranges - consumed_domain.ranges[new_axis] = SymbolicRange( + new_range = SymbolicRange( im.literal("0", ir.INTEGER_INDEX_BUILTIN), im.literal(str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN), ) + consumed_domain.ranges = dict( + (axis, range_) if axis != old_axis else (new_axis, new_range) + for axis, range_ in consumed_domain.ranges.items() + ) else: - raise NotImplementedError + raise NotImplementedError() consumed_domains.append(consumed_domain) # compute the bounds of all consumed domains if consumed_domains: - domains[param] = domain_union(consumed_domains).as_expr() + if all( + consumed_domain.ranges.keys() == consumed_domains[0].ranges.keys() + for consumed_domain in consumed_domains + ): # scalar otherwise + domains[param] = domain_union(consumed_domains).as_expr() return FencilWithTemporaries( fencil=ir.FencilDefinition( @@ -569,10 +622,18 @@ class CreateGlobalTmps(NodeTranslator): """ def visit_FencilDefinition( - self, node: ir.FencilDefinition, *, offset_provider: Mapping[str, Any] + self, + node: ir.FencilDefinition, + *, + offset_provider: Mapping[str, Any], + extraction_heuristics: Optional[ + Callable[[ir.StencilClosure], Callable[[ir.Expr], bool]] + ] = None, ) -> FencilWithTemporaries: # Split closures on lifted function calls and introduce temporaries - res = split_closures(node, offset_provider=offset_provider) + res = split_closures( + node, offset_provider=offset_provider, extraction_heuristics=extraction_heuristics + ) # Prune unreferences closure inputs introduced in the previous step res = PruneClosureInputs().visit(res) # Prune unused temporaries possibly introduced in the previous step diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 0ff3ec25c7..d9bfa916f3 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import enum +from typing import Callable, Optional from gt4py.next.iterator import ir from gt4py.next.iterator.transforms import simple_inline_heuristic @@ -50,8 +51,6 @@ def _inline_lifts(ir, lift_mode): return InlineLifts( flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT | InlineLifts.Flag.INLINE_DEREF_LIFT # some tuple exprs found in FVM don't work yet. - | InlineLifts.Flag.INLINE_LIFTED_ARGS - # needed for UnrollReduce and lift args like `(↑(λ() → constant)` ).visit(ir) else: raise ValueError() @@ -72,6 +71,8 @@ def _inline_into_scan(ir, *, max_iter=10): return ir +# TODO(tehrengruber): Revisit interface to configure temporary extraction. We currently forward +# `lift_mode` and `temporary_extraction_heuristics` which is inconvenient. def apply_common_transforms( ir: ir.Node, *, @@ -80,6 +81,9 @@ def apply_common_transforms( unroll_reduce=False, common_subexpression_elimination=True, unconditionally_collapse_tuples=False, + temporary_extraction_heuristics: Optional[ + Callable[[ir.StencilClosure], Callable[[ir.Expr], bool]] + ] = None, ): if lift_mode is None: lift_mode = LiftMode.FORCE_INLINE @@ -114,6 +118,32 @@ def apply_common_transforms( else: raise RuntimeError("Inlining lift and lambdas did not converge.") + if lift_mode != LiftMode.FORCE_INLINE: + assert offset_provider is not None + ir = CreateGlobalTmps().visit( + ir, + offset_provider=offset_provider, + extraction_heuristics=temporary_extraction_heuristics, + ) + + for _ in range(10): + inlined = InlineLifts().visit(ir) + inlined = InlineLambdas.apply( + inlined, + opcount_preserving=True, + force_inline_lift_args=True, + ) + if inlined == ir: + break + ir = inlined + else: + raise RuntimeError("Inlining lift and lambdas did not converge.") + + # If after creating temporaries, the scan is not at the top, we inline. + # The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it. + # λ(inp) → scan(λ(state, k, kp) → state + ·k + ·kp, True, 0.0)(inp, ⟪Koffₒ, 1ₒ⟫(inp))` + ir = _inline_into_scan(ir) + # Since `CollapseTuple` relies on the type inference which does not support returning tuples # larger than the number of closure outputs as given by the unconditional collapse, we can # only run the unconditional version here instead of in the loop above. @@ -127,6 +157,7 @@ def apply_common_transforms( ir = FuseMaps().visit(ir) ir = CollapseListGet().visit(ir) + if unroll_reduce: for _ in range(10): unrolled = UnrollReduce.apply(ir, offset_provider=offset_provider) @@ -135,20 +166,11 @@ def apply_common_transforms( ir = unrolled ir = CollapseListGet().visit(ir) ir = NormalizeShifts().visit(ir) - ir = _inline_lifts(ir, lift_mode) + ir = _inline_lifts(ir, LiftMode.FORCE_INLINE) ir = NormalizeShifts().visit(ir) else: raise RuntimeError("Reduction unrolling failed.") - if lift_mode != LiftMode.FORCE_INLINE: - assert offset_provider is not None - ir = CreateGlobalTmps().visit(ir, offset_provider=offset_provider) - ir = InlineLifts().visit(ir) - # If after creating temporaries, the scan is not at the top, we inline. - # The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it. - # λ(inp) → scan(λ(state, k, kp) → state + ·k + ·kp, True, 0.0)(inp, ⟪Koffₒ, 1ₒ⟫(inp))` - ir = _inline_into_scan(ir) - ir = EtaReduction().visit(ir) ir = ScanEtaReduction().visit(ir) diff --git a/src/gt4py/next/iterator/type_inference.py b/src/gt4py/next/iterator/type_inference.py index 14f3e95e10..f42646b6c4 100644 --- a/src/gt4py/next/iterator/type_inference.py +++ b/src/gt4py/next/iterator/type_inference.py @@ -21,6 +21,7 @@ import gt4py.next as gtx from gt4py.next.common import Connectivity from gt4py.next.iterator import ir +from gt4py.next.iterator.transforms.global_tmps import FencilWithTemporaries from gt4py.next.type_inference import Type, TypeVar, freshen, reindex_vars, unify @@ -931,6 +932,9 @@ def visit_StencilClosure( ) return Closure(output=output, inputs=Tuple.from_elems(*inputs)) + def visit_FencilWithTemporaries(self, node: FencilWithTemporaries, **kwargs): + return self.visit(node.fencil, **kwargs) + def visit_FencilDefinition( self, node: ir.FencilDefinition, diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py index 4183f52550..1715a4ad0d 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py @@ -35,6 +35,7 @@ def _lower( offset_provider=offset_provider, unroll_reduce=do_unroll, unconditionally_collapse_tuples=True, # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements + temporary_extraction_heuristics=kwargs.get("temporary_extraction_heuristics"), ) gtfn_ir = GTFN_lowering.apply( program, diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 5e24e855b5..a986c840d3 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -16,7 +16,7 @@ import dataclasses import warnings -from typing import Any, Final, Optional, TypeVar +from typing import Any, Callable, Final, Optional, TypeVar import numpy as np @@ -53,6 +53,9 @@ class GTFNTranslationStep( enable_itir_transforms: bool = True # TODO replace by more general mechanism, see https://github.com/GridTools/gt4py/issues/1135 use_imperative_backend: bool = False lift_mode: Optional[LiftMode] = None + temporary_extraction_heuristics: Optional[ + Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] + ] = None def _process_regular_arguments( self, @@ -181,10 +184,10 @@ def __call__( # here and warn the user if it differs from the one configured. runtime_lift_mode = inp.kwargs.pop("lift_mode", None) lift_mode = runtime_lift_mode or self.lift_mode - if runtime_lift_mode != self.lift_mode: + if runtime_lift_mode and runtime_lift_mode != self.lift_mode: warnings.warn( f"GTFN Backend was configured for LiftMode `{str(self.lift_mode)}`, but " - "overriden to be {str(runtime_lift_mode)} at runtime." + f"overriden to be {str(runtime_lift_mode)} at runtime." ) # combine into a format that is aligned with what the backend expects @@ -202,6 +205,7 @@ def __call__( enable_itir_transforms=self.enable_itir_transforms, lift_mode=lift_mode, imperative=self.use_imperative_backend, + temporary_extraction_heuristics=self.temporary_extraction_heuristics, **inp.kwargs, ) source_code = interface.format_source( diff --git a/src/gt4py/next/program_processors/runners/gtfn_cpu.py b/src/gt4py/next/program_processors/runners/gtfn_cpu.py index 31b8323474..7939366c28 100644 --- a/src/gt4py/next/program_processors/runners/gtfn_cpu.py +++ b/src/gt4py/next/program_processors/runners/gtfn_cpu.py @@ -18,7 +18,7 @@ from gt4py.eve.utils import content_hash from gt4py.next import common -from gt4py.next.iterator.transforms import LiftMode +from gt4py.next.iterator.transforms import LiftMode, global_tmps from gt4py.next.otf import languages, recipes, stages, workflow from gt4py.next.otf.binding import cpp_interface, nanobind from gt4py.next.otf.compilation import cache, compiler @@ -134,6 +134,9 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int: ]( name="run_gtfn_with_temporaries", otf_workflow=run_gtfn.otf_workflow.replace( - translation=run_gtfn.otf_workflow.translation.replace(lift_mode=LiftMode.FORCE_TEMPORARIES), + translation=run_gtfn.otf_workflow.translation.replace( + lift_mode=LiftMode.FORCE_TEMPORARIES, + temporary_extraction_heuristics=global_tmps.SimpleTemporaryExtractionHeuristics, + ), ), ) From ace4bf9d2500ba772dd2ecd3f5790583f5194683 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 24 Sep 2023 13:50:07 +0200 Subject: [PATCH 2/8] Add test case for temporary extraction heuristics --- .../transforms_tests/test_global_tmps.py | 166 +++++++++--------- 1 file changed, 84 insertions(+), 82 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index 88f6ed517b..2f8b1a4c36 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -19,6 +19,7 @@ from gt4py.next.iterator.transforms.global_tmps import ( AUTO_DOMAIN, FencilWithTemporaries, + SimpleTemporaryExtractionHeuristics, Temporary, collect_tmps_info, split_closures, @@ -31,53 +32,23 @@ def test_split_closures(): testee = ir.FencilDefinition( id="f", function_definitions=[], - params=[ir.Sym(id="d"), ir.Sym(id="inp"), ir.Sym(id="out")], + params=[im.sym("d"), im.sym("inp"), im.sym("out")], closures=[ ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="baz_inp")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="lift"), - args=[ - ir.Lambda( - params=[ir.Sym(id="bar_inp")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="lift"), - args=[ - ir.Lambda( - params=[ir.Sym(id="foo_inp")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.SymRef(id="foo_inp") - ], - ), - ) - ], - ), - args=[ir.SymRef(id="bar_inp")], - ) - ], - ), - ) - ], - ), - args=[ir.SymRef(id="baz_inp")], + domain=im.call("cartesian_domain")(), + stencil=im.lambda_("baz_inp")( + im.deref( + im.lift( + im.lambda_("bar_inp")( + im.deref( + im.lift(im.lambda_("foo_inp")(im.deref("foo_inp")))("bar_inp") + ) ) - ], - ), + )("baz_inp") + ) ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="inp")], + output=im.ref("out"), + inputs=[im.ref("inp")], ) ], ) @@ -86,54 +57,31 @@ def test_split_closures(): id="f", function_definitions=[], params=[ - ir.Sym(id="d"), - ir.Sym(id="inp"), - ir.Sym(id="out"), - ir.Sym(id="_tmp_1"), - ir.Sym(id="_tmp_2"), - ir.Sym(id="_gtmp_auto_domain"), + im.sym("d"), + im.sym("inp"), + im.sym("out"), + im.sym("_tmp_1"), + im.sym("_tmp_2"), + im.sym("_gtmp_auto_domain"), ], closures=[ ir.StencilClosure( domain=AUTO_DOMAIN, - stencil=ir.Lambda( - params=[ir.Sym(id="foo_inp")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ir.SymRef(id="foo_inp")], - ), - ), - output=ir.SymRef(id="_tmp_2"), - inputs=[ir.SymRef(id="inp")], + stencil=im.lambda_("foo_inp")(im.deref("foo_inp")), + output=im.ref("_tmp_2"), + inputs=[im.ref("inp")], ), ir.StencilClosure( domain=AUTO_DOMAIN, - stencil=ir.Lambda( - params=[ - ir.Sym(id="bar_inp"), - ir.Sym(id="_tmp_2"), - ], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.SymRef(id="_tmp_2"), - ], - ), - ), - output=ir.SymRef(id="_tmp_1"), - inputs=[ir.SymRef(id="inp"), ir.SymRef(id="_tmp_2")], + stencil=im.lambda_("bar_inp", "_tmp_2")(im.deref("_tmp_2")), + output=im.ref("_tmp_1"), + inputs=[im.ref("inp"), im.ref("_tmp_2")], ), ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="baz_inp"), ir.Sym(id="_tmp_1")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ir.SymRef(id="_tmp_1")], - ), - ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="inp"), ir.SymRef(id="_tmp_1")], + domain=im.call("cartesian_domain")(), + stencil=im.lambda_("baz_inp", "_tmp_1")(im.deref("_tmp_1")), + output=im.ref("out"), + inputs=[im.ref("inp"), im.ref("_tmp_1")], ), ], ) @@ -142,6 +90,60 @@ def test_split_closures(): assert actual.fencil == expected +def test_split_closures_simple_heuristics(): + UIDs.reset_sequence() + testee = ir.FencilDefinition( + id="f", + function_definitions=[], + params=[im.sym("d"), im.sym("inp"), im.sym("out")], + closures=[ + ir.StencilClosure( + domain=im.call("cartesian_domain")(), + stencil=im.lambda_("foo")( + im.let("lifted_it", im.lift(im.lambda_("bar")(im.deref("bar")))("foo"))( + im.plus(im.deref("lifted_it"), im.deref(im.shift("I", 1)("lifted_it"))) + ) + ), + output=im.ref("out"), + inputs=[im.ref("inp")], + ) + ], + ) + + expected = ir.FencilDefinition( + id="f", + function_definitions=[], + params=[ + im.sym("d"), + im.sym("inp"), + im.sym("out"), + im.sym("_tmp_1"), + im.sym("_gtmp_auto_domain"), + ], + closures=[ + ir.StencilClosure( + domain=AUTO_DOMAIN, + stencil=im.lambda_("bar")(im.deref("bar")), + output=im.ref("_tmp_1"), + inputs=[im.ref("inp")], + ), + ir.StencilClosure( + domain=im.call("cartesian_domain")(), + stencil=im.lambda_("foo", "_tmp_1")( + im.plus(im.deref("_tmp_1"), im.deref(im.shift("I", 1)("_tmp_1"))) + ), + output=im.ref("out"), + inputs=[im.ref("inp"), im.ref("_tmp_1")], + ), + ], + ) + actual = split_closures( + testee, extraction_heuristics=SimpleTemporaryExtractionHeuristics, offset_provider={} + ) + assert actual.tmps == [Temporary(id="_tmp_1")] + assert actual.fencil == expected + + def test_split_closures_lifted_scan(): UIDs.reset_sequence() From 83577d20fd57e63b0094f9afd5b39a567598dec2 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 24 Sep 2023 13:59:09 +0200 Subject: [PATCH 3/8] Fix typo --- src/gt4py/next/iterator/transforms/global_tmps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 5a0f97d32f..0ba1312803 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -153,7 +153,7 @@ class TemporaryExtractionPredicate: """ Construct a callable that determines if a lift expr can and should be extracted to a temporary. - The class optionally takes a heuristics that can restrict the extraction. + The class optionally takes a heuristic that can restrict the extraction. """ heuristics: Optional[Callable[[ir.Expr], bool]] = None From 979408cd668fa9e073d913885cc978a030ae87f5 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 25 Oct 2023 12:18:24 +0200 Subject: [PATCH 4/8] Small fixes --- .../next/iterator/transforms/global_tmps.py | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 0ba1312803..8c553d79fa 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -178,7 +178,7 @@ def __call__(self, expr: ir.Expr, num_occurences: int) -> bool: @dataclasses.dataclass(frozen=True) class SimpleTemporaryExtractionHeuristics: - """Heuristic that extracts only if a lift expr is derefed in (at most) one position.""" + """Heuristic that extracts only if a lift expr is derefed in one position.""" closure: ir.StencilClosure @@ -188,9 +188,13 @@ def closure_shifts(self): def __call__(self, expr: ir.Expr) -> bool: shifts = self.closure_shifts[id(expr)] - if len(shifts) <= 1: - return False - return True + # Lift expressions that are never dereferenced are not extracted as we can not deduce + # a domain for them (and thus can not generate a temporary). These expressions only occur + # in combination with the scan pass (as they are otherwise removed earlier by the lift + # and lambda inliner) and are removed later using the scan inliner. + if len(shifts) == 1: + return True + return False def _closure_parameter_argument_mapping(closure: ir.StencilClosure): @@ -492,7 +496,12 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An if closure.domain == AUTO_DOMAIN: # every closure with auto domain should have a single out field assert isinstance(closure.output, ir.SymRef) + + if closure.output.id not in domains: + raise NotImplementedError(f"Closure output {closure.output.id} is never used.") + domain = domains[closure.output.id] + closure = ir.StencilClosure( domain=copy.deepcopy(domain), stencil=closure.stencil, @@ -504,14 +513,6 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An closures.append(closure) - if closure.stencil == ir.SymRef(id="deref"): - # all closure inputs inherit the domain - for input_arg in _tuple_constituents(closure.inputs[0]): - assert isinstance(input_arg, ir.SymRef) - assert domains.get(input_arg.id, domain) == domain - domains[input_arg.id] = domain - continue - local_shifts = trace_shifts.TraceShifts.apply(closure) for param, shift_chains in local_shifts.items(): assert isinstance(param, str) From 73e586648da287fc74ac0ab8d0ed2c4df9a49384 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 26 Oct 2023 11:29:35 +0200 Subject: [PATCH 5/8] Small fixes --- src/gt4py/next/iterator/transforms/global_tmps.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 8c553d79fa..b5beeebf76 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -162,15 +162,15 @@ def __call__(self, expr: ir.Expr, num_occurences: int) -> bool: """Determine if `expr` is an applied lift that should be extracted as a temporary.""" if not is_applied_lift(expr): return False - # do not extract when the result is a list as we can not create temporaries for - # these stencils + # do not extract when the result is a list (i.e. a lift expression used in a `reduce` call) + # as we can not create temporaries for these stencils if isinstance(expr.annex.type.dtype, type_inference.List): return False if self.heuristics and not self.heuristics(expr): return False stencil = expr.fun.args[0] # type: ignore[attr-defined] # ensured by `is_applied_lift` - used_symbols = collect_symbol_refs(stencil) # do not extract when the stencil is capturing + used_symbols = collect_symbol_refs(stencil) if used_symbols: return False return True @@ -178,7 +178,7 @@ def __call__(self, expr: ir.Expr, num_occurences: int) -> bool: @dataclasses.dataclass(frozen=True) class SimpleTemporaryExtractionHeuristics: - """Heuristic that extracts only if a lift expr is derefed in one position.""" + """Heuristic that extracts only if a lift expr is derefed in more than one position.""" closure: ir.StencilClosure @@ -188,11 +188,7 @@ def closure_shifts(self): def __call__(self, expr: ir.Expr) -> bool: shifts = self.closure_shifts[id(expr)] - # Lift expressions that are never dereferenced are not extracted as we can not deduce - # a domain for them (and thus can not generate a temporary). These expressions only occur - # in combination with the scan pass (as they are otherwise removed earlier by the lift - # and lambda inliner) and are removed later using the scan inliner. - if len(shifts) == 1: + if len(shifts) > 1: return True return False From b739de56013537007d0a09c4c2cca285527f47e7 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 18 Jan 2024 14:52:16 +0100 Subject: [PATCH 6/8] Small cleanup --- src/gt4py/next/iterator/transforms/global_tmps.py | 3 ++- src/gt4py/next/iterator/transforms/pass_manager.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index ba9c5d8aa6..746f094dd4 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -495,7 +495,7 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An assert isinstance(closure.output, ir.SymRef) if closure.output.id not in domains: - raise NotImplementedError(f"Closure output {closure.output.id} is never used.") + raise NotImplementedError(f"Closure output '{closure.output.id}' is never used.") domain = domains[closure.output.id] @@ -532,6 +532,7 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An im.literal("0", ir.INTEGER_INDEX_BUILTIN), im.literal(str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN), ) + # TODO(tehrengruber): Revisit. Somehow the order matters so preserve it. consumed_domain.ranges = dict( (axis, range_) if axis != old_axis else (new_axis, new_range) for axis, range_ in consumed_domain.ranges.items() diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index b13a312ce2..9b9a11548f 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -142,7 +142,7 @@ def apply_common_transforms( break ir = inlined else: - raise RuntimeError("Inlining lift and lambdas did not converge.") + raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") # If after creating temporaries, the scan is not at the top, we inline. # The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it. From a95acea4549910ec8309ad53a8c7f454ba633742 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 18 Jan 2024 14:56:21 +0100 Subject: [PATCH 7/8] Small cleanup --- src/gt4py/next/iterator/transforms/global_tmps.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 746f094dd4..ac1c01327b 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -179,7 +179,12 @@ def __call__(self, expr: ir.Expr, num_occurences: int) -> bool: @dataclasses.dataclass(frozen=True) class SimpleTemporaryExtractionHeuristics: - """Heuristic that extracts only if a lift expr is derefed in more than one position.""" + """ + Heuristic that extracts only if a lift expr is derefed in more than one position. + + Note that such expression result in redundant computations if inlined instead of being + placed into a temporary. + """ closure: ir.StencilClosure From 73a65f12549563f0083a9edb248b55da13f5a2af Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 8 Feb 2024 10:40:01 +0100 Subject: [PATCH 8/8] Fix broken CI --- .../feature_tests/ffront_tests/test_execution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 7fc2d82e67..ae5e434085 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -622,7 +622,7 @@ def simple_scan_operator(carry: float) -> float: @pytest.mark.uses_lift_expressions @pytest.mark.uses_scan_nested def test_solve_triag(cartesian_case): - if cartesian_case.executor == gtfn.run_gtfn_with_temporaries: + if cartesian_case.executor == gtfn.run_gtfn_with_temporaries.executor: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") @gtx.scan_operator(axis=KDim, forward=True, init=(0.0, 0.0))