-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature[next]: Temporary extraction heuristics #1341
feature[next]: Temporary extraction heuristics #1341
Conversation
@@ -31,53 +32,23 @@ def test_split_closures(): | |||
testee = ir.FencilDefinition( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changes in this function are merely a refactoring to use ir.makers
.
@@ -134,6 +134,9 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int: | |||
]( | |||
name="run_gtfn_with_temporaries", | |||
otf_workflow=run_gtfn.otf_workflow.replace( | |||
translation=run_gtfn.otf_workflow.translation.replace(lift_mode=LiftMode.FORCE_TEMPORARIES), | |||
translation=run_gtfn.otf_workflow.translation.replace( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Temporary extraction without the heuristics is not useful in almost all cases. I just switched to using the heuristics by default.
@@ -483,19 +530,25 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An | |||
nbt_provider = offset_provider[offset_name] | |||
old_axis = nbt_provider.origin_axis.value | |||
new_axis = nbt_provider.neighbor_axis.value | |||
consumed_domain.ranges.pop(old_axis) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unrelated to this PR. This approach of popping from the dict failed when the domain had more than one axis as the order was not preserved.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why does the order matter?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now that I think about it, I got to admit I don't know. With the embedded backend it shouldn't be a problem, but with gtfn it did not work, but just silently resulted in all values being zero. Not sure what's the best way to proceed, shall I just create an issue to investigate?
used_symbols = collect_symbol_refs(stencil) | ||
# do not extract when the stencil is capturing | ||
if used_symbols: | ||
return False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I cannot not comment on this, but I let you guess what my opinion is and you may ignore it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Haha, in this case I'm ignoring it. It's just too short, both the section in question and the surrounding to put it into a function.
# Lift expressions that are never dereferenced are not extracted as we can not deduce | ||
# a domain for them (and thus can not generate a temporary). These expressions only occur | ||
# in combination with the scan pass (as they are otherwise removed earlier by the lift | ||
# and lambda inliner) and are removed later using the scan inliner. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't fully understand why this is here, probably because you discovered the scan part while you were working on the pass. My first reaction was: obviously you should not extract something that is not derefed because it means the value is not used. Not sure if you can reformulate in that direction or am I missing something else?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Forget about this comment I confused myself. The docstring here was wrong, we only extract when the expr is dereferenced in more than one position. The problem I stumbled upon is something else and I will fix it together with the scans another time.
if not extraction_heuristics: | ||
# extract all (eligible) lifts | ||
def always_extract_heuristics(_): | ||
return lambda _: True | ||
|
||
extraction_heuristics = always_extract_heuristics |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we put this as default argument?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That wouldn't be very handy. The heuristics is passed through from the backend configuration trough the pass manager, temporary extraction pass until it ends up here. By using None
we can easily just specify None
in the backend configuration and it gets passed through until it is here translated into the default heuristics.
@@ -483,19 +530,25 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An | |||
nbt_provider = offset_provider[offset_name] | |||
old_axis = nbt_provider.origin_axis.value | |||
new_axis = nbt_provider.neighbor_axis.value | |||
consumed_domain.ranges.pop(old_axis) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why does the order matter?
def visit_FencilWithTemporaries(self, node: FencilWithTemporaries, **kwargs): | ||
return self.visit(node.fencil, **kwargs) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, the pass does not return nodes again, but their type. As such the generic visit would try to create a FencilWithTemporaries
where fencil
has then been transformed into a type and then fail with e.g.
TypeError: 'FencilWithTemporaries.fencil' must be <class 'gt4py.next.iterator.ir.FencilDefinition'> (got 'FencilDefinitionType(name='__field_operator_testee', fundefs=EmptyTuple(), params=Tuple(front=TypeVar(idx=1305), others=Tuple(front=TypeVar(idx=1314), others=Tuple(front=TypeVar(idx=1323), others=Tuple(front=TypeVar(idx=1324), others=Tuple(front=TypeVar(idx=1325), others=Tuple(front=TypeVar(idx=1326), others=Tuple(front=TypeVar(idx=1327), others=Tuple(front=TypeVar(idx=1328), others=EmptyTuple())))))))))' which is a <class 'gt4py.next.iterator.type_inference.FencilDefinitionType'>).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small changes to conform to upcoming error message guidelines.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One comment about your TODO, maybe you want to update it.
@@ -512,13 +561,22 @@ def update_domains( | |||
(axis, range_) if axis != old_axis else (new_axis, new_range) | |||
for axis, range_ in consumed_domain.ranges.items() | |||
) | |||
# TODO(tehrengruber): Revisit. Somehow the order matters so preserve it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for unstructured_domain(horizontal, vertical)
the order matters, is that the point here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm guessing that it is more the TemporaryAllocation
where the order matters, but I am not sure. The unstructured_domain
is also using TaggedValues so it should be fine.
cscs-ci run |
Adds a heuristics that only extracts a temporary if the respective lift expr is derefed in more than one position. This should give reasonably good performance and avoids many unnecessary temporaries.