-
Notifications
You must be signed in to change notification settings - Fork 231
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
types: Remove increments from SubDimension names #2257
Changes from 19 commits
9dc4682
da6975b
d1a73d7
0426692
f2322c5
15cb6a9
bec6aa7
52a9a6e
99c4d49
39aa594
0442dc2
3b2a7d9
b0308a1
f86178b
e95a51b
59dd3ae
b946c7c
0418235
2f64dfc
6477cd8
907d1d4
b29d087
6e3f274
66abbc5
60a5d0d
fc322c1
6eeb17b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,16 @@ | ||
from collections.abc import Iterable | ||
from itertools import groupby | ||
|
||
from sympy import sympify | ||
|
||
from devito.symbolics import retrieve_indexed, uxreplace, retrieve_dimensions | ||
from devito.symbolics import (retrieve_indexed, uxreplace, dxreplace, | ||
retrieve_dimensions) | ||
from devito.tools import Ordering, as_tuple, flatten, filter_sorted, filter_ordered | ||
from devito.ir.support import pull_dims | ||
from devito.types import Dimension, IgnoreDimSort | ||
from devito.types.basic import AbstractFunction | ||
|
||
__all__ = ['dimension_sort', 'lower_exprs'] | ||
__all__ = ['dimension_sort', 'lower_exprs', 'separate_dimensions'] | ||
|
||
|
||
def dimension_sort(expr): | ||
|
@@ -147,3 +150,50 @@ def lower_exprs(expressions, **kwargs): | |
else: | ||
assert len(processed) == 1 | ||
return processed.pop() | ||
|
||
|
||
def separate_dimensions(expressions): | ||
""" | ||
Rename Dimensions with clashing names within expressions. | ||
""" | ||
resolutions = {} | ||
count = {} # Keep track of increments on dim names | ||
processed = [] | ||
for e in expressions: | ||
# Just want dimensions which appear in the expression | ||
# Dimensions in indices | ||
dims = set().union(*tuple(set(i.function.dimensions) | ||
for i in retrieve_indexed(e))) | ||
# Dimensions in conditions and ConditionalDimension parents | ||
dims = dims.union(*tuple(pull_dims(d.condition, flag=False) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you probably want to unclutter here and move most of this logic into ideally, all you need to do here would be:
because here you're performing a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IIrc, what's required here is more picky than There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same, no need for tuple |
||
for d in dims if d.is_Conditional | ||
and d.condition is not None), | ||
set(d.parent for d in dims if d.is_Conditional)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the whole thing seems to be a bit intricated There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The issue is that Fabio pointed out that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Why would a dimsion use as a free symbol rather than an index be ignored? It doesn't change the issue it's gonna be the wrong value if not changed. Any time both the conditional and the parent are used together irrespective of how is a clash. |
||
# Sort for groupby | ||
dims = sorted(dims, key=lambda x: x.name) | ||
|
||
# Group dimensions by matching names | ||
groups = tuple(tuple(g) for n, g in groupby(dims, key=lambda x: x.name)) | ||
clashes = tuple(g for g in groups if len(g) > 1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you only need generator no need for |
||
|
||
subs = {} | ||
for c in clashes: | ||
# Preferentially rename dims that aren't root dims | ||
rdims = tuple(d for d in c if d.is_Root) | ||
ddims = tuple(d for d in c if not d.is_Root) | ||
|
||
for d in (ddims + rdims)[:-1]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this is safe, a |
||
if d in resolutions.keys(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Need to make sure you don't replace a clash by another but that case gets tricky because you can't set There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I could always check that the incremented dimension name isn't in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also no need for .keys() |
||
subs[d] = resolutions[d] | ||
else: | ||
try: | ||
subs[d] = d._rebuild(d.name+str(count[d.name])) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. devito standard is formating |
||
count[d.name] += 1 | ||
except KeyError: | ||
subs[d] = d._rebuild(d.name+'0') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same, |
||
count[d.name] = 1 | ||
resolutions[d] = subs[d] | ||
|
||
processed.append(dxreplace(e, subs)) | ||
|
||
return processed |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,15 +9,17 @@ | |
from devito.symbolics.extended_sympy import DefFunction, rfunc | ||
from devito.symbolics.queries import q_leaf | ||
from devito.symbolics.search import retrieve_indexed, retrieve_functions | ||
from devito.tools import as_list, as_tuple, flatten, split, transitive_closure | ||
from devito.tools import (as_list, as_tuple, flatten, split, transitive_closure, | ||
frozendict) | ||
from devito.types.basic import Basic | ||
from devito.types.array import ComponentAccess | ||
from devito.types.equation import Eq | ||
from devito.types.relational import Le, Lt, Gt, Ge | ||
from devito.types.relational import Le, Lt, Gt, Ge, AbstractRel | ||
from devito.types.dimension import ConditionalDimension | ||
|
||
__all__ = ['xreplace_indices', 'pow_to_mul', 'indexify', 'subs_op_args', | ||
'normalize_args', 'uxreplace', 'Uxmapper', 'reuse_if_untouched', | ||
'evalrel'] | ||
'evalrel', 'dxreplace'] | ||
|
||
|
||
def uxreplace(expr, rule): | ||
|
@@ -43,10 +45,19 @@ def uxreplace(expr, rule): | |
Finally, `uxreplace` supports Reconstructable objects, that is, it searches | ||
for replacement opportunities inside the Reconstructable's `__rkwargs__`. | ||
""" | ||
return _uxreplace(expr, rule)[0] | ||
return _uxreplace(expr, rule, mode='ux')[0] | ||
|
||
|
||
def _uxreplace(expr, rule): | ||
def dxreplace(expr, rule): | ||
""" | ||
As `uxreplace`, albeit with systematic replacement of dimensions in the | ||
expression, including those contained within attached `ConditionalDimension`s | ||
and `SubDomain`s, which would not be touched by the standard `uxreplace`. | ||
""" | ||
return _uxreplace(expr, rule, mode='dx')[0] | ||
|
||
|
||
def _uxreplace(expr, rule, mode='ux'): | ||
if expr in rule: | ||
v = rule[expr] | ||
if not isinstance(v, dict): | ||
|
@@ -69,53 +80,85 @@ def _uxreplace(expr, rule): | |
changed = False | ||
|
||
if rule: | ||
eargs, flag = _uxreplace_dispatch(eargs, rule) | ||
eargs, flag = _uxreplace_dispatch(eargs, rule, mode=mode) | ||
args.extend(eargs) | ||
|
||
changed |= flag | ||
|
||
# If a Reconstructable object, we need to parse the kwargs as well | ||
if _uxreplace_registry.dispatchable(expr): | ||
v = {i: getattr(expr, i) for i in expr.__rkwargs__} | ||
kwargs, flag = _uxreplace_dispatch(v, rule) | ||
# Select the correct registry for the replacement mode specified | ||
try: | ||
registry = _replace_registries[mode] | ||
except KeyError: | ||
raise ValueError("Mode '%s' does not match any replacement mode" | ||
% mode) | ||
|
||
# If a Reconstructable object, we need to parse args and kwargs | ||
if registry.dispatchable(expr): | ||
if not args: | ||
try: | ||
v = [getattr(expr, i) for i in expr.__rargs__] | ||
except AttributeError: | ||
# Reconstructable has no required args | ||
v = [] | ||
args, aflag = _uxreplace_dispatch(v, rule, mode=mode) | ||
else: | ||
aflag = False | ||
|
||
try: | ||
# Get the underlying attribute originally used to build | ||
# the object rather than the property, as the latter can | ||
# sometimes have default values which do not match the | ||
# default for the class | ||
v = {i: getattr(expr, "_"+i) for i in expr.__rkwargs__} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
What if there isn't one and there is only the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not quite sure I understand? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Some type have attribute such as |
||
except AttributeError: | ||
# Reconstructable has no required kwargs | ||
v = {} | ||
kwargs, kwflag = _uxreplace_dispatch(v, rule, mode=mode) | ||
else: | ||
kwargs, flag = {}, False | ||
aflag = False | ||
kwargs, kwflag = {}, False | ||
flag = aflag | kwflag | ||
changed |= flag | ||
|
||
if changed: | ||
return _uxreplace_handle(expr, args, kwargs), True | ||
|
||
return expr, False | ||
|
||
|
||
@singledispatch | ||
def _uxreplace_dispatch(unknown, rule): | ||
def _uxreplace_dispatch(unknown, rule, mode='ux'): | ||
return unknown, False | ||
|
||
|
||
@_uxreplace_dispatch.register(Basic) | ||
def _(expr, rule): | ||
return _uxreplace(expr, rule) | ||
def _(expr, rule, mode='ux'): | ||
return _uxreplace(expr, rule, mode=mode) | ||
|
||
|
||
@_uxreplace_dispatch.register(AbstractRel) | ||
def _(expr, rule, mode='ux'): | ||
return _uxreplace(expr, rule, mode=mode) | ||
|
||
|
||
@_uxreplace_dispatch.register(tuple) | ||
@_uxreplace_dispatch.register(Tuple) | ||
@_uxreplace_dispatch.register(list) | ||
def _(iterable, rule): | ||
def _(iterable, rule, mode='ux'): | ||
ret = [] | ||
changed = False | ||
for a in iterable: | ||
ax, flag = _uxreplace(a, rule) | ||
ax, flag = _uxreplace(a, rule, mode=mode) | ||
ret.append(ax) | ||
changed |= flag | ||
return iterable.__class__(ret), changed | ||
|
||
|
||
@_uxreplace_dispatch.register(dict) | ||
def _(mapper, rule): | ||
def _(mapper, rule, mode='ux'): | ||
ret = {} | ||
changed = False | ||
for k, v in mapper.items(): | ||
vx, flag = _uxreplace_dispatch(v, rule) | ||
vx, flag = _uxreplace_dispatch(v, rule, mode=mode) | ||
ret[k] = vx | ||
changed |= flag | ||
return ret, changed | ||
|
@@ -180,13 +223,28 @@ def register(self, cls, rkwargs_callback_mapper=None): | |
_uxreplace_dispatch.register(kls, callback) | ||
|
||
def dispatchable(self, obj): | ||
# If not deep, ignore objects associated with deep uxreplace | ||
return isinstance(obj, tuple(self)) | ||
|
||
|
||
# Registry for default uxreplace | ||
_uxreplace_registry = UxreplaceRegistry() | ||
_uxreplace_registry.register(Eq) | ||
_uxreplace_registry.register(DefFunction) | ||
_uxreplace_registry.register(ComponentAccess) | ||
# Classes which only want uxreplacing when deep=True specified | ||
# _uxreplace_registry.register(ConditionalDimension, deep=True) | ||
|
||
# Registry for dimension uxreplace | ||
_dxreplace_registry = UxreplaceRegistry() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. rename as |
||
_dxreplace_registry.register(Eq) | ||
_dxreplace_registry.register(DefFunction) | ||
_dxreplace_registry.register(ComponentAccess) | ||
_dxreplace_registry.register(ConditionalDimension) | ||
|
||
# Create a dict of registries | ||
_replace_registries = frozendict({'ux': _uxreplace_registry, | ||
'dx': _dxreplace_registry}) | ||
|
||
|
||
class Uxmapper(dict): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -200,6 +200,10 @@ def is_const(self): | |
def root(self): | ||
return self | ||
|
||
@property | ||
def is_Root(self): | ||
return self == self.root | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is self.root |
||
|
||
@cached_property | ||
def bound_symbols(self): | ||
candidates = [self.symbolic_min, self.symbolic_max, self.symbolic_size, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Un-necessary extra "set" work, just
set().union(*[i.function.dimensions ...])