feature[next]: Temporary extraction heuristics #1341

Merged

tehrengruber
Contributor

@tehrengruber tehrengruber commented Sep 18, 2023

Adds a heuristic that only extracts a temporary if the respective lift expr is dereferenced in more than one position. This should give reasonably good performance and avoid many unnecessary temporaries.
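
To make the rule concrete, here is a minimal toy sketch. It is not the gt4py implementation: the node classes `Deref` and `Call` and both helper functions are hypothetical, and it assumes that "position" means a distinct offset at which the lifted iterator is read.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Deref:
    """Toy node: read iterator `name` after shifting it by `offset`."""

    name: str
    offset: int = 0


@dataclass(frozen=True)
class Call:
    """Toy node: apply function `fun` to the argument expressions."""

    fun: str
    args: tuple


def deref_positions(expr, name: str) -> set:
    """Collect the distinct offsets at which iterator `name` is dereferenced."""
    if isinstance(expr, Deref):
        return {expr.offset} if expr.name == name else set()
    if isinstance(expr, Call):
        positions = set()
        for arg in expr.args:
            positions |= deref_positions(arg, name)
        return positions
    return set()


def should_extract(expr, name: str) -> bool:
    """Assumed rule: extract only if the value is read at more than one position."""
    return len(deref_positions(expr, name)) > 1


# `tmp` is read at two neighbouring positions: inlining would compute it twice.
stencil_like = Call("plus", (Deref("tmp", -1), Deref("tmp", 1)))
# `tmp` is read at a single position: inlining costs nothing extra.
pointwise = Call("plus", (Deref("tmp"), Deref("other")))
```

Under this reading, `should_extract(stencil_like, "tmp")` holds while `should_extract(pointwise, "tmp")` does not, which is the case where a temporary would only add memory traffic.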

@@ -31,53 +32,23 @@ def test_split_closures():
testee = ir.FencilDefinition(
Contributor Author

Changes in this function are merely a refactoring to use ir.makers.

@@ -134,6 +134,9 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int:
](
name="run_gtfn_with_temporaries",
otf_workflow=run_gtfn.otf_workflow.replace(
translation=run_gtfn.otf_workflow.translation.replace(lift_mode=LiftMode.FORCE_TEMPORARIES),
translation=run_gtfn.otf_workflow.translation.replace(
Contributor Author

Temporary extraction without the heuristic is not useful in almost all cases, so I switched to using the heuristic by default.

@@ -483,19 +530,25 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An
nbt_provider = offset_provider[offset_name]
old_axis = nbt_provider.origin_axis.value
new_axis = nbt_provider.neighbor_axis.value
consumed_domain.ranges.pop(old_axis)
Contributor Author

Unrelated to this PR. This approach of popping from the dict failed when the domain had more than one axis as the order was not preserved.

Contributor

why does the order matter?

Contributor Author

Now that I think about it, I have to admit I don't know. With the embedded backend it shouldn't be a problem, but with gtfn it did not work and silently resulted in all values being zero. Not sure what the best way to proceed is. Shall I just create an issue to investigate?
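
For readers following the thread, the ordering effect itself can be reproduced with plain dicts. This is only a sketch: the axis names and ranges are made up, both helper functions are hypothetical, and it assumes the new axis was inserted under its own key after the pop.

```python
def replace_axis_by_pop(ranges: dict, old_axis: str, new_axis: str, new_range) -> dict:
    """Pop the old axis, then add the new one: the new axis ends up last."""
    result = dict(ranges)
    result.pop(old_axis)
    result[new_axis] = new_range
    return result


def replace_axis_keep_position(ranges: dict, old_axis: str, new_axis: str, new_range) -> dict:
    """Rebuild the dict so the new axis takes over the position of the old one."""
    return dict(
        (axis, range_) if axis != old_axis else (new_axis, new_range)
        for axis, range_ in ranges.items()
    )


# Illustrative two-axis domain; the names and bounds are assumptions.
ranges = {"Vertex": (0, 10), "K": (0, 5)}
popped = replace_axis_by_pop(ranges, "Vertex", "Edge", (0, 20))
rebuilt = replace_axis_keep_position(ranges, "Vertex", "Edge", (0, 20))
```

With a single axis both variants agree, which would explain why the problem only showed up for domains with more than one axis. Why the consumer is sensitive to the order is a separate question that this sketch does not answer.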

@tehrengruber tehrengruber requested a review from havogt September 24, 2023 11:59
src/gt4py/next/iterator/transforms/global_tmps.py (outdated)
Comment on lines 172 to 175
used_symbols = collect_symbol_refs(stencil)
# do not extract when the stencil is capturing
if used_symbols:
return False
Contributor

I cannot not comment on this, but I'll let you guess what my opinion is, and you may ignore it.

Contributor Author

Haha, in this case I'm ignoring it. Both the section in question and the surrounding code are just too short to put it into a function.

Comment on lines 191 to 194
# Lift expressions that are never dereferenced are not extracted as we can not deduce
# a domain for them (and thus can not generate a temporary). These expressions only occur
# in combination with the scan pass (as they are otherwise removed earlier by the lift
# and lambda inliner) and are removed later using the scan inliner.
Contributor

I don't fully understand why this is here, probably because you discovered the scan part while you were working on the pass. My first reaction was: obviously you should not extract something that is not derefed because it means the value is not used. Not sure if you can reformulate in that direction or am I missing something else?

Contributor Author

Forget about this comment, I confused myself. The docstring here was wrong: we only extract when the expr is dereferenced in more than one position. The problem I stumbled upon is something else, and I will fix it together with the scans another time.

src/gt4py/next/iterator/transforms/global_tmps.py (outdated)
Comment on lines +246 to +251
if not extraction_heuristics:
# extract all (eligible) lifts
def always_extract_heuristics(_):
return lambda _: True

extraction_heuristics = always_extract_heuristics
Contributor

Can we put this as default argument?

Contributor Author

That wouldn't be very handy. The heuristic is passed from the backend configuration through the pass manager and the temporary extraction pass until it ends up here. By using None, we can simply specify None in the backend configuration, and it gets passed through until it is translated into the default heuristic here.
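
A rough sketch of this pass-through pattern, under stated assumptions: the three layer functions `run_backend`, `run_pass_manager` and `extract_temporaries` are hypothetical stand-ins, not the actual gt4py call chain, and the check uses `is None` where the snippet under review uses `not extraction_heuristics`.

```python
from typing import Callable, Optional

# A heuristic factory takes a closure and returns a predicate over lift expressions.
HeuristicFactory = Callable[[object], Callable[[object], bool]]


def always_extract_heuristics(_closure):
    """Fallback: extract every (eligible) lift."""
    return lambda _expr: True


def extract_temporaries(closure, extraction_heuristics: Optional[HeuristicFactory] = None):
    # Innermost layer: the only place where None is turned into the fallback.
    if extraction_heuristics is None:
        extraction_heuristics = always_extract_heuristics
    return extraction_heuristics(closure)


def run_pass_manager(closure, extraction_heuristics: Optional[HeuristicFactory] = None):
    # Intermediate layer: forwards the value untouched, None included.
    return extract_temporaries(closure, extraction_heuristics)


def run_backend(closure, extraction_heuristics: Optional[HeuristicFactory] = None):
    # Outermost layer: a backend configuration may simply leave the option unset.
    return run_pass_manager(closure, extraction_heuristics)


default_predicate = run_backend("some closure")
custom_predicate = run_backend("some closure", lambda _closure: (lambda expr: expr == "x"))
```

The point of the design is that none of the outer layers needs to import or know the fallback. Only the innermost function decides what None means.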

Comment on lines +935 to +937
def visit_FencilWithTemporaries(self, node: FencilWithTemporaries, **kwargs):
return self.visit(node.fencil, **kwargs)

Contributor

do we need this?

Contributor Author

Yes. The pass does not return nodes, but their types. As such, the generic visit would try to create a FencilWithTemporaries whose fencil has been transformed into a type, and would then fail with, e.g.:

TypeError: 'FencilWithTemporaries.fencil' must be <class 'gt4py.next.iterator.ir.FencilDefinition'> (got 'FencilDefinitionType(name='__field_operator_testee', fundefs=EmptyTuple(), params=Tuple(front=TypeVar(idx=1305), others=Tuple(front=TypeVar(idx=1314), others=Tuple(front=TypeVar(idx=1323), others=Tuple(front=TypeVar(idx=1324), others=Tuple(front=TypeVar(idx=1325), others=Tuple(front=TypeVar(idx=1326), others=Tuple(front=TypeVar(idx=1327), others=Tuple(front=TypeVar(idx=1328), others=EmptyTuple())))))))))' which is a <class 'gt4py.next.iterator.type_inference.FencilDefinitionType'>).
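
The failure mode can be imitated in a few lines. This is a toy model, not the gt4py visitor: all classes below are hypothetical, types are represented as plain strings, and the field validation is a hand-written `__post_init__` check.

```python
from dataclasses import dataclass, fields


@dataclass
class Fencil:
    name: str


@dataclass
class FencilWithTemps:
    fencil: Fencil

    def __post_init__(self):
        # Mimic a validated field: only real Fencil nodes are accepted.
        if not isinstance(self.fencil, Fencil):
            raise TypeError(f"'fencil' must be Fencil, got {type(self.fencil).__name__}")


class GenericTranslator:
    """Dispatch on the class name; by default rebuild the node from visited children."""

    def visit(self, node):
        method = getattr(self, f"visit_{type(node).__name__}", self.generic_visit)
        return method(node)

    def generic_visit(self, node):
        children = {f.name: self.visit(getattr(node, f.name)) for f in fields(node)}
        return type(node)(**children)


class ToyTypeInference(GenericTranslator):
    """Returns type descriptions (strings here) instead of nodes."""

    def visit_Fencil(self, node):
        return f"FencilType({node.name})"


class FixedTypeInference(ToyTypeInference):
    def visit_FencilWithTemps(self, node):
        # Skip the wrapper and infer the type of the wrapped fencil directly.
        return self.visit(node.fencil)


node = FencilWithTemps(Fencil("testee"))
try:
    ToyTypeInference().visit(node)
    failed = False
except TypeError:
    failed = True  # the generic visit tried to rebuild the wrapper around a type
fixed_result = FixedTypeInference().visit(node)
```

Without the dedicated visit method, the generic path rebuilds the wrapper around a type and the validation rejects it. With it, the wrapper is transparent to the pass.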

@tehrengruber tehrengruber requested a review from havogt November 20, 2023 13:54
Contributor

@DropD DropD left a comment

Small changes to conform to upcoming error message guidelines.

src/gt4py/next/iterator/transforms/global_tmps.py (outdated)
src/gt4py/next/iterator/transforms/pass_manager.py (outdated)
Contributor

@havogt havogt left a comment

One comment about your TODO, maybe you want to update it.

@@ -512,13 +561,22 @@ def update_domains(
(axis, range_) if axis != old_axis else (new_axis, new_range)
for axis, range_ in consumed_domain.ranges.items()
)
# TODO(tehrengruber): Revisit. Somehow the order matters so preserve it.
Contributor

for unstructured_domain(horizontal, vertical) the order matters, is that the point here?

Contributor Author

@tehrengruber tehrengruber Feb 7, 2024

I'm guessing that it is more the TemporaryAllocation where the order matters, but I am not sure. The unstructured_domain is also using TaggedValues so it should be fine.

@tehrengruber
Contributor Author

cscs-ci run

@tehrengruber tehrengruber merged commit e24f52d into GridTools:main Feb 8, 2024
31 checks passed
3 participants