-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature[next]: Temporary extraction heuristics #1341
Changes from 8 commits
120a37e
ace4bf9
83577d2
fdad59b
979408c
73e5866
a92c3f1
15d7549
c183ba2
b739de5
569b6ec
a95acea
f21690c
d4f2589
73a65f1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,7 @@ | |
import dataclasses | ||
import functools | ||
from collections.abc import Mapping | ||
from typing import Any, Final, Iterable, Literal, Optional, Sequence | ||
from typing import Any, Callable, Final, Iterable, Literal, Optional, Sequence | ||
|
||
import gt4py.eve as eve | ||
import gt4py.next as gtx | ||
|
@@ -148,20 +148,49 @@ def canonicalize_applied_lift(closure_params: list[str], node: ir.FunCall) -> ir | |
return node | ||
|
||
|
||
def temporary_extraction_predicate(expr: ir.Node, num_occurences: int) -> bool: | ||
"""Determine if `expr` is an applied lift that should be extracted as a temporary.""" | ||
if not is_applied_lift(expr): | ||
return False | ||
# do not extract when the result is a list as we can not create temporaries for | ||
# these stencils | ||
if isinstance(expr.annex.type.dtype, type_inference.List): | ||
return False | ||
stencil = expr.fun.args[0] # type: ignore[attr-defined] # ensured by `is_applied_lift` | ||
used_symbols = collect_symbol_refs(stencil) | ||
# do not extract when the stencil is capturing | ||
if used_symbols: | ||
@dataclasses.dataclass(frozen=True) | ||
class TemporaryExtractionPredicate: | ||
""" | ||
Construct a callable that determines if a lift expr can and should be extracted to a temporary. | ||
|
||
The class optionally takes a heuristic that can restrict the extraction. | ||
""" | ||
|
||
heuristics: Optional[Callable[[ir.Expr], bool]] = None | ||
|
||
def __call__(self, expr: ir.Expr, num_occurences: int) -> bool: | ||
"""Determine if `expr` is an applied lift that should be extracted as a temporary.""" | ||
if not is_applied_lift(expr): | ||
return False | ||
# do not extract when the result is a list (i.e. a lift expression used in a `reduce` call) | ||
# as we can not create temporaries for these stencils | ||
if isinstance(expr.annex.type.dtype, type_inference.List): | ||
return False | ||
if self.heuristics and not self.heuristics(expr): | ||
return False | ||
stencil = expr.fun.args[0] # type: ignore[attr-defined] # ensured by `is_applied_lift` | ||
# do not extract when the stencil is capturing | ||
used_symbols = collect_symbol_refs(stencil) | ||
if used_symbols: | ||
return False | ||
return True | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class SimpleTemporaryExtractionHeuristics: | ||
"""Heuristic that extracts only if a lift expr is derefed in more than one position.""" | ||
|
||
closure: ir.StencilClosure | ||
|
||
@functools.cached_property | ||
def closure_shifts(self): | ||
return trace_shifts.TraceShifts.apply(self.closure, inputs_only=False) | ||
|
||
def __call__(self, expr: ir.Expr) -> bool: | ||
shifts = self.closure_shifts[id(expr)] | ||
if len(shifts) > 1: | ||
return True | ||
return False | ||
return True | ||
|
||
|
||
def _closure_parameter_argument_mapping(closure: ir.StencilClosure): | ||
|
@@ -191,7 +220,14 @@ def _ensure_expr_does_not_capture(expr: ir.Expr, whitelist: list[ir.Sym]) -> Non | |
assert not (set(used_symbol_refs) - {param.id for param in whitelist}) | ||
|
||
|
||
def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemporaries: | ||
def split_closures( | ||
node: ir.FencilDefinition, | ||
offset_provider, | ||
*, | ||
extraction_heuristics: Optional[ | ||
Callable[[ir.StencilClosure], Callable[[ir.Expr], bool]] | ||
] = None, | ||
) -> FencilWithTemporaries: | ||
"""Split closures on lifted function calls and introduce new temporary buffers for return values. | ||
|
||
Newly introduced temporaries will have the symbolic size of `AUTO_DOMAIN`. A symbol with the | ||
|
@@ -203,6 +239,13 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp | |
3. Extract lifted function class as new closures with the previously created temporary as output. | ||
The closures are processed in reverse order to properly respect the dependencies. | ||
""" | ||
if not extraction_heuristics: | ||
# extract all (eligible) lifts | ||
def always_extract_heuristics(_): | ||
return lambda _: True | ||
|
||
extraction_heuristics = always_extract_heuristics | ||
|
||
uid_gen_tmps = UIDGenerator(prefix="_tmp") | ||
|
||
type_inference.infer_all(node, offset_provider=offset_provider, save_to_annex=True) | ||
|
@@ -226,9 +269,13 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp | |
current_closure.stencil if not is_scan else current_closure.stencil.args[0] # type: ignore[attr-defined] # ensured by is_scan | ||
) | ||
|
||
extraction_predicate = TemporaryExtractionPredicate( | ||
extraction_heuristics(current_closure) | ||
) | ||
|
||
stencil_body, extracted_lifts, _ = extract_subexpression( | ||
current_closure_stencil.expr, | ||
temporary_extraction_predicate, | ||
extraction_predicate, | ||
uid_gen_tmps, | ||
once_only=True, | ||
deepest_expr_first=True, | ||
|
@@ -445,7 +492,12 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An | |
if closure.domain == AUTO_DOMAIN: | ||
# every closure with auto domain should have a single out field | ||
assert isinstance(closure.output, ir.SymRef) | ||
|
||
if closure.output.id not in domains: | ||
raise NotImplementedError(f"Closure output {closure.output.id} is never used.") | ||
tehrengruber marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
domain = domains[closure.output.id] | ||
|
||
closure = ir.StencilClosure( | ||
domain=copy.deepcopy(domain), | ||
stencil=closure.stencil, | ||
|
@@ -457,14 +509,6 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An | |
|
||
closures.append(closure) | ||
|
||
if closure.stencil == ir.SymRef(id="deref"): | ||
# all closure inputs inherit the domain | ||
for input_arg in _tuple_constituents(closure.inputs[0]): | ||
assert isinstance(input_arg, ir.SymRef) | ||
assert domains.get(input_arg.id, domain) == domain | ||
domains[input_arg.id] = domain | ||
continue | ||
|
||
local_shifts = trace_shifts.TraceShifts.apply(closure) | ||
for param, shift_chains in local_shifts.items(): | ||
assert isinstance(param, str) | ||
|
@@ -483,19 +527,25 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An | |
nbt_provider = offset_provider[offset_name] | ||
old_axis = nbt_provider.origin_axis.value | ||
new_axis = nbt_provider.neighbor_axis.value | ||
consumed_domain.ranges.pop(old_axis) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unrelated to this PR. This approach of popping from the dict failed when the domain had more than one axis as the order was not preserved. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why does the order matter? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now that I think about it, I got to admit I don't know. With the embedded backend it shouldn't be a problem, but with gtfn it did not work, but just silently resulted in all values being zero. Not sure what's the best way to proceed, shall I just create an issue to investigate? |
||
assert new_axis not in consumed_domain.ranges | ||
consumed_domain.ranges[new_axis] = SymbolicRange( | ||
new_range = SymbolicRange( | ||
im.literal("0", ir.INTEGER_INDEX_BUILTIN), | ||
im.literal(str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN), | ||
) | ||
consumed_domain.ranges = dict( | ||
(axis, range_) if axis != old_axis else (new_axis, new_range) | ||
for axis, range_ in consumed_domain.ranges.items() | ||
) | ||
else: | ||
raise NotImplementedError | ||
raise NotImplementedError() | ||
consumed_domains.append(consumed_domain) | ||
|
||
# compute the bounds of all consumed domains | ||
if consumed_domains: | ||
domains[param] = domain_union(consumed_domains).as_expr() | ||
if all( | ||
consumed_domain.ranges.keys() == consumed_domains[0].ranges.keys() | ||
for consumed_domain in consumed_domains | ||
): # scalar otherwise | ||
domains[param] = domain_union(consumed_domains).as_expr() | ||
|
||
return FencilWithTemporaries( | ||
fencil=ir.FencilDefinition( | ||
|
@@ -569,10 +619,18 @@ class CreateGlobalTmps(NodeTranslator): | |
""" | ||
|
||
def visit_FencilDefinition( | ||
self, node: ir.FencilDefinition, *, offset_provider: Mapping[str, Any] | ||
self, | ||
node: ir.FencilDefinition, | ||
*, | ||
offset_provider: Mapping[str, Any], | ||
extraction_heuristics: Optional[ | ||
Callable[[ir.StencilClosure], Callable[[ir.Expr], bool]] | ||
] = None, | ||
) -> FencilWithTemporaries: | ||
# Split closures on lifted function calls and introduce temporaries | ||
res = split_closures(node, offset_provider=offset_provider) | ||
res = split_closures( | ||
node, offset_provider=offset_provider, extraction_heuristics=extraction_heuristics | ||
) | ||
# Prune unreferences closure inputs introduced in the previous step | ||
res = PruneClosureInputs().visit(res) | ||
# Prune unused temporaries possibly introduced in the previous step | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,7 @@ | |
import gt4py.next as gtx | ||
from gt4py.next.common import Connectivity | ||
from gt4py.next.iterator import ir | ||
from gt4py.next.iterator.transforms.global_tmps import FencilWithTemporaries | ||
from gt4py.next.type_inference import Type, TypeVar, freshen, reindex_vars, unify | ||
|
||
|
||
|
@@ -931,6 +932,9 @@ def visit_StencilClosure( | |
) | ||
return Closure(output=output, inputs=Tuple.from_elems(*inputs)) | ||
|
||
def visit_FencilWithTemporaries(self, node: FencilWithTemporaries, **kwargs): | ||
return self.visit(node.fencil, **kwargs) | ||
|
||
Comment on lines
+940
to
+942
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we need this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, the pass does not return nodes again, but their type. As such the generic visit would try to create a
|
||
def visit_FencilDefinition( | ||
self, | ||
node: ir.FencilDefinition, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we put this as default argument?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That wouldn't be very handy. The heuristics is passed through from the backend configuration trough the pass manager, temporary extraction pass until it ends up here. By using
None
we can easily just specifyNone
in the backend configuration and it gets passed through until it is here translated into the default heuristics.