-
Notifications
You must be signed in to change notification settings - Fork 231
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
types: Remove increments from SubDimension names #2257
Changes from all commits
9dc4682
da6975b
d1a73d7
0426692
f2322c5
15cb6a9
bec6aa7
52a9a6e
99c4d49
39aa594
0442dc2
3b2a7d9
b0308a1
f86178b
e95a51b
59dd3ae
b946c7c
0418235
2f64dfc
6477cd8
907d1d4
b29d087
6e3f274
66abbc5
60a5d0d
fc322c1
6eeb17b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,16 @@ | ||
from collections.abc import Iterable | ||
from itertools import groupby | ||
|
||
from sympy import sympify | ||
|
||
from devito.symbolics import retrieve_indexed, uxreplace, retrieve_dimensions | ||
from devito.symbolics import (retrieve_indexed, uxreplace, dxreplace, | ||
retrieve_dimensions) | ||
from devito.tools import Ordering, as_tuple, flatten, filter_sorted, filter_ordered | ||
from devito.ir.support import pull_dims | ||
from devito.types import Dimension, IgnoreDimSort | ||
from devito.types.basic import AbstractFunction | ||
|
||
__all__ = ['dimension_sort', 'lower_exprs'] | ||
__all__ = ['dimension_sort', 'lower_exprs', 'separate_dimensions'] | ||
|
||
|
||
def dimension_sort(expr): | ||
|
@@ -147,3 +150,51 @@ def lower_exprs(expressions, **kwargs): | |
else: | ||
assert len(processed) == 1 | ||
return processed.pop() | ||
|
||
|
||
def separate_dimensions(expressions): | ||
""" | ||
Rename Dimensions with clashing names within expressions. | ||
""" | ||
resolutions = {} | ||
count = {} # Keep track of increments on dim names | ||
processed = [] | ||
for e in expressions: | ||
# Just want dimensions which appear in the expression | ||
# Dimensions in indices | ||
dims = set().union(*[i.function.dimensions | ||
for i in retrieve_indexed(e)]) | ||
# Dimensions in conditions and ConditionalDimension parents | ||
dims = dims.union(*[pull_dims(d.condition, flag=False) | ||
for d in dims if d.is_Conditional | ||
and d.condition is not None], | ||
set(d.parent for d in dims if d.is_Conditional)) | ||
# Sort for groupby | ||
dims = sorted(dims, key=lambda x: x.name) | ||
|
||
# TODO: Needs to check for dimensions in factors too | ||
|
||
# Group dimensions by matching names | ||
groups = tuple(tuple(g) for n, g in groupby(dims, key=lambda x: x.name)) | ||
clashes = tuple(g for g in groups if len(g) > 1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you only need generator no need for |
||
|
||
subs = {} | ||
for c in clashes: | ||
# Preferentially rename dims that aren't root dims | ||
ddims = tuple(d for d in c if not d.is_Root) | ||
|
||
for d in ddims: | ||
if d in resolutions.keys(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Need to make sure you don't replace a clash by another but that case gets tricky because you can't set There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I could always check that the incremented dimension name isn't in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also no need for .keys() |
||
subs[d] = resolutions[d] | ||
else: | ||
try: | ||
subs[d] = d._rebuild('%s%s' % (d.name, count[d.name])) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use the You can plumb it down the compiler pass from the caller site -- it's one of the kwargs |
||
count[d.name] += 1 | ||
except KeyError: | ||
subs[d] = d._rebuild('%s0' % d.name) | ||
count[d.name] = 1 | ||
resolutions[d] = subs[d] | ||
|
||
processed.append(dxreplace(e, subs)) | ||
|
||
return processed |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ | |
from devito.symbolics import IntDiv, uxreplace | ||
from devito.tools import Pickable, Tag, frozendict | ||
from devito.types import Eq, Inc, ReduceMax, ReduceMin | ||
from devito.symbolics.manipulation import _dxreplace_registry | ||
|
||
__all__ = ['LoweredEq', 'ClusterizedEq', 'DummyEq', 'OpInc', 'OpMin', 'OpMax'] | ||
|
||
|
@@ -220,8 +221,11 @@ def writes(self): | |
def xreplace(self, rules): | ||
return LoweredEq(self.lhs.xreplace(rules), self.rhs.xreplace(rules), **self.state) | ||
|
||
def func(self, *args): | ||
return self._rebuild(*args, evaluate=False) | ||
def func(self, *args, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. since it's a Reconstructable, and in particular a Pickable, I think that the following should suffice:
|
||
return self._rebuild(*args, **kwargs, evaluate=False) | ||
|
||
|
||
_dxreplace_registry.register(LoweredEq) | ||
|
||
|
||
class ClusterizedEq(IREq): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
from devito.tools import (Ordering, Stamp, as_list, as_tuple, filter_ordered, | ||
flatten, frozendict, is_integer, toposort) | ||
from devito.types import Dimension, ModuloDimension | ||
from devito.symbolics.manipulation import _uxreplace_dispatch | ||
|
||
__all__ = ['NullInterval', 'Interval', 'IntervalGroup', 'IterationSpace', | ||
'IterationInterval', 'DataSpace', 'Forward', 'Backward', 'Any', | ||
|
@@ -1000,4 +1001,43 @@ def nonderived_directions(self): | |
return {k: v for k, v in self.directions.items() if not k.is_Derived} | ||
|
||
|
||
@_uxreplace_dispatch.register(Interval) | ||
def _(expr, rule, mode='ux'): | ||
changed = False | ||
dim, flag = _uxreplace_dispatch(expr.dim, rule, mode=mode) | ||
changed |= flag | ||
lower, flag = _uxreplace_dispatch(expr.lower, rule, mode=mode) | ||
changed |= flag | ||
upper, flag = _uxreplace_dispatch(expr.upper, rule, mode=mode) | ||
changed |= flag | ||
|
||
return Interval(dim, lower=lower, upper=upper, stamp=expr.stamp), changed | ||
|
||
|
||
@_uxreplace_dispatch.register(IterationSpace) | ||
def _(expr, rule, mode='ux'): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. but most importantly, see comments on slack |
||
intervals = [] | ||
relations = set() | ||
changed = False | ||
for interval in expr.intervals: | ||
i, flag = _uxreplace_dispatch(interval, rule, mode=mode) | ||
intervals.append(i) | ||
changed |= flag | ||
for relation in expr.intervals.relations: | ||
r, flag = _uxreplace_dispatch(relation, rule, mode=mode) | ||
relations.add(r) | ||
changed |= flag | ||
si, flag = _uxreplace_dispatch(expr.sub_iterators, rule, mode=mode) | ||
changed |= flag | ||
di, flag = _uxreplace_dispatch(expr.directions, rule, mode=mode) | ||
changed |= flag | ||
|
||
intervals = sorted(intervals, key=lambda x: x.dim.name) | ||
intervals = IntervalGroup(intervals, relations=frozenset(relations), | ||
mode=expr.intervals.mode) | ||
|
||
return IterationSpace(intervals, sub_iterators=si, | ||
directions=di), changed | ||
|
||
|
||
null_ispace = IterationSpace([]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the whole thing seems to be a bit intricated
retrieve_dimensions(e, deep=True) + e.implicit_dims
should give you all dimensions.Also why are we ignoring the conditionals with a factor?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The issue is that
retrieve_dimensions
collects too many things, not too few (even with flag=False). It wants to retrieve only indices in the equation, conditional dimension parent, and dimensions in the conditional dimension condition. Good point on the factor, that should probably be added in.Fabio pointed out that
separate_equations
probably wants to be after creation ofLoweredEQ
, so may need to rethink this anyway.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why would a dimsion use as a free symbol rather than an index be ignored? It doesn't change the issue it's gonna be the wrong value if not changed. Any time both the conditional and the parent are used together irrespective of how is a clash.